mirror of
https://github.com/tiennm99/litellm.git
synced 2026-06-17 18:48:36 +00:00
perf: reduce per-request and per-chunk overhead across Anthropic streaming hot paths (#28289)
* perf: reduce per-request and per-chunk overhead across Anthropic streaming hot paths
- Introduce pure-text fast-path in `_build_complete_streaming_response` that collapses O(N) `content_block_delta` events into a single equivalent SSE event before conversion, eliminating per-output-token Pydantic `ModelResponseStream` construction; non-text streams (tool_use, thinking, citations) fall back to the unchanged legacy path
- Skip agentic streaming wrapper entirely when no callback overrides `async_should_run_agentic_loop`; the wrapper buffered every chunk and rebuilt the SSE response only to call hooks that all return `(False, {})` — a pure no-op for the default config
- Serialize request body once (`json.dumps`) for both the pre-call log input and the wire, instead of twice; avoids a full O(payload) scan per request, significant for long-context Claude Code histories
- Add fast path in `async_streaming_data_generator` that bypasses the per-chunk `async_post_call_streaming_hook` coroutine await, response-string materialization, and cost-injection call when no callback/guardrail/cost-injection is active (the default config)
- Resolve `_DD_STREAMING_TRACE_ENABLED` once at import time; eliminate per-chunk `NullSpan` context manager allocation when Datadog tracing is disabled (the default)
- Memoize `get_type_hints(AnthropicMessagesRequestOptionalParams)` with `@lru_cache(maxsize=1)` — resolves once per process instead of once per `/v1/messages` request (~80µs each)
- Hoist `cost_injection_active` out of the per-chunk loop in `chunk_processor`; eliminates repeated `getattr` + endpoint-type checks on every streamed byte chunk
- Extract `_build_passthrough_logging_result` from `_route_streaming_logging_to_handler` as a standalone static method to facilitate future off-loop dispatch
- Convert `async_sse_data_generator` from an `async for: yield` trampoline to a direct return of the underlying generator, removing one async-generator layer per streamed chunk
- Skip redundant `strip_empty_text_blocks_from_anthropic_messages` scan in `anthropic_messages_handler` when the async wrapper already sanitized (signalled via `_litellm_messages_presanitized` sentinel, popped before reaching provider params)
- Gate debug log `f-string` evaluation behind `isEnabledFor(DEBUG)` in both the streaming generator and the transformation layer to avoid serializing entire message payloads on every request at non-debug log levels
- Add benchmark script (`scripts/benchmark_anthropic_messages_perf.py`) with a local mock Anthropic SSE provider for reproducible TTFT and TPM measurement across commits/branches
- Add parity tests asserting fast-path and legacy-path produce byte-identical logged/billed payloads, plus unit tests for agentic hook detection, pre-serialized body reuse, and memoized key resolution
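The pure-text fast path in the first bullet can be sketched as follows. This is a minimal illustration and not the code from this commit: `collapse_text_deltas` is a hypothetical helper, and it assumes each event is a bare `data: {...}` line rather than a full multi-line SSE record. It merges runs of `text_delta` events into one event and returns `None` for anything it does not recognise, so the caller can fall back to the legacy path.

```python
import json
from typing import List, Optional


def collapse_text_deltas(events: List[str]) -> Optional[List[str]]:
    """Merge consecutive text_delta events; return None for non-text streams.

    Hypothetical helper for illustration only.
    """
    out: List[str] = []
    pending: List[str] = []
    pending_index: Optional[int] = None

    def flush() -> None:
        # Emit the buffered text run as one equivalent delta event.
        nonlocal pending, pending_index
        if pending:
            merged = {
                "type": "content_block_delta",
                "index": pending_index,
                "delta": {"type": "text_delta", "text": "".join(pending)},
            }
            out.append("data: " + json.dumps(merged))
        pending, pending_index = [], None

    for ev in events:
        if not ev.startswith("data:"):
            return None  # this sketch only handles bare data lines
        data = json.loads(ev[len("data:"):])
        if data.get("type") == "content_block_delta":
            delta = data.get("delta") or {}
            if delta.get("type") != "text_delta":
                return None  # tool_use / thinking / citations: use legacy path
            if pending and data.get("index") != pending_index:
                return None  # interleaved blocks: refuse to merge
            pending_index = data.get("index")
            pending.append(delta.get("text") or "")
        else:
            flush()
            out.append(ev)
    flush()
    return out
```

A stream with N text deltas then yields one merged delta event, so the downstream converter builds one chunk object for the text run instead of N.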
* perf: address greptile review for anthropic streaming hot path
- Bail to legacy in `_collapse_pure_text_chunks` when content_block_delta
events from different block indexes are observed without an intervening
flush. Anthropic sends blocks strictly sequentially, but defensive bail
prevents silent text-merging if the protocol ever interleaves.
- Replace leaf-class `__dict__` check for `async_post_call_streaming_hook`
in `_callback_capabilities` with a function-identity comparison that
walks the MRO. A vendor base class can carry the override and the
registered class can add nothing else; before this PR the hook was
unconditionally invoked, so an inherited-override miss would silently
drop the hook on the streaming path.
- Add unit tests for both behaviors.
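The difference between a leaf-class `__dict__` check and a function-identity comparison can be shown with a small sketch. The class names and the `hook` method below are hypothetical stand-ins, not litellm classes. `getattr` on the class resolves the attribute through the MRO, so an override carried by an intermediate base class is still seen.

```python
class BaseLogger:
    """Stand-in for a base callback class (hypothetical)."""

    async def hook(self, chunk):
        return chunk


class VendorBase(BaseLogger):
    """Intermediate class that carries the override."""

    async def hook(self, chunk):
        return chunk.upper()


class Registered(VendorBase):
    """Leaf class that adds nothing of its own."""


class Plain(BaseLogger):
    """Leaf class with no override anywhere in its MRO."""


def overrides_via_leaf_dict(cb, name="hook"):
    # Only looks at the leaf class, so inherited overrides are missed.
    return name in type(cb).__dict__


def overrides_via_identity(cb, name="hook", base=BaseLogger):
    # getattr walks the MRO; compare the resolved function with the base one.
    return getattr(type(cb), name) is not getattr(base, name)
```

With these definitions, `overrides_via_leaf_dict(Registered())` reports no override even though `Registered` inherits one from `VendorBase`, while `overrides_via_identity(Registered())` detects it.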
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
* fix(mypy): narrow model_name to str in cost-injection branch
The hoisted cost_injection_active flag in chunk_processor encodes the
`bool(model_name)` requirement but mypy can't track that invariant
through the local, so the per-chunk `_process_chunk_with_cost_injection(
chunk, model_name)` calls flagged Optional[str] vs str. Pin a typed
non-None local inside the cost-injection branch so mypy narrows
correctly without changing runtime behavior.
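A minimal sketch of the narrowing pattern described above, using hypothetical `inject` and `process` functions rather than the real `chunk_processor`. The flag encodes "model_name is truthy", but a type checker cannot follow that through a boolean local, so an `assert` plus a typed local pins the type inside the branch.

```python
from typing import List, Optional


def inject(chunk: bytes, model_name: str) -> bytes:
    # Requires a real str, not Optional[str].
    return chunk + model_name.encode()


def process(chunks: List[bytes], model_name: Optional[str], enabled: bool) -> List[bytes]:
    # Hoisted once: the flag already implies model_name is truthy.
    active = enabled and bool(model_name)
    if not active:
        return list(chunks)
    # The checker cannot derive "model_name is str" from `active`,
    # so narrow explicitly; runtime behavior is unchanged.
    assert model_name is not None
    resolved: str = model_name
    return [inject(c, resolved) for c in chunks]
```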
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
---------
Co-authored-by: Yassin Kortam <yassinkortam@g.ucla.edu>
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@@ -293,6 +293,12 @@ async def anthropic_messages(
         api_base=api_base,
         client=client,
         custom_llm_provider=custom_llm_provider,
+        # messages were already empty-text-block sanitized at the top of this
+        # function and are NOT reassigned before this dispatch, so the handler
+        # can skip its (otherwise redundant) second full-messages scan. Passed
+        # explicitly (not via **kwargs) so it only affects this direct
+        # dispatch -- interceptor / sync entry points still sanitize.
+        _litellm_messages_presanitized=True,
         **kwargs,
     )
     ctx = contextvars.copy_context()
@@ -351,10 +357,14 @@ def anthropic_messages_handler(
     """
     from litellm.types.utils import LlmProviders

-    # Sanitize empty text blocks here too so the sync entry point
+    # Sanitize empty text blocks so the sync entry point
     # (litellm.messages.create -> anthropic_messages_handler) gets the same
-    # protection as the async wrapper. Idempotent when called twice.
-    messages = strip_empty_text_blocks_from_anthropic_messages(messages)
+    # protection as the async wrapper. The async wrapper already sanitized and
+    # does not reassign messages before dispatch, so it sets
+    # ``_litellm_messages_presanitized`` to skip this redundant second
+    # full-messages scan. Pop it so it never leaks into provider params.
+    if not kwargs.pop("_litellm_messages_presanitized", False):
+        messages = strip_empty_text_blocks_from_anthropic_messages(messages)

     metadata = validate_anthropic_api_metadata(metadata)

@@ -312,7 +312,10 @@ class AnthropicMessagesConfig(BaseAnthropicMessagesConfig):
         )

         ####### get required params for all anthropic messages requests ######
-        verbose_logger.debug(f"TRANSFORMATION DEBUG - Messages: {messages}")
+        # Lazy %s: the f-string previously stringified the entire messages
+        # payload on every request regardless of log level (a full scan of the
+        # request body on the hot path). Defer it to when DEBUG is enabled.
+        verbose_logger.debug("TRANSFORMATION DEBUG - Messages: %s", messages)

         # Auto-strip advisor blocks from history if advisor tool is absent.
         # Prevents Anthropic 400: advisor_tool_result in history requires advisor tool.
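The cost difference between an f-string argument and lazy `%s` formatting can be checked with a small sketch. The `Expensive` class and the `lazy_demo` logger name are hypothetical. The f-string is evaluated before `debug()` is even called, whereas the `%s` form is only formatted if the record is actually emitted.

```python
import logging


class Expensive:
    """Counts how often it is stringified (stand-in for a large payload)."""

    def __init__(self) -> None:
        self.str_calls = 0

    def __str__(self) -> str:
        self.str_calls += 1
        return "payload"


logger = logging.getLogger("lazy_demo")
logger.setLevel(logging.INFO)  # DEBUG is disabled

eager = Expensive()
logger.debug(f"Messages: {eager}")  # f-string is built before the call

lazy = Expensive()
logger.debug("Messages: %s", lazy)  # formatting deferred, then skipped
```

After this runs, `eager.str_calls` is 1 and `lazy.str_calls` is 0, even though neither message is logged.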
@@ -1,4 +1,5 @@
-from typing import Any, Dict, List, cast, get_type_hints
+from functools import lru_cache
+from typing import Any, Dict, FrozenSet, List, cast, get_type_hints

 from litellm.types.llms.anthropic import AnthropicMessagesRequestOptionalParams
 from litellm.types.llms.anthropic_messages.anthropic_response import (
@@ -6,6 +7,18 @@ from litellm.types.llms.anthropic_messages.anthropic_response import (
 )


+@lru_cache(maxsize=1)
+def _anthropic_messages_optional_param_keys() -> FrozenSet[str]:
+    """
+    Valid AnthropicMessagesRequestOptionalParams keys.
+
+    ``typing.get_type_hints`` is ~80us/call and this TypedDict is static, so
+    resolving it once per process instead of once per request removes a fixed
+    full-pass cost from the /v1/messages request-parse path.
+    """
+    return frozenset(get_type_hints(AnthropicMessagesRequestOptionalParams).keys())
+
+
 class AnthropicMessagesRequestUtils:
     @staticmethod
     def get_requested_anthropic_messages_optional_param(
@@ -20,7 +33,7 @@ class AnthropicMessagesRequestUtils:
         Returns:
             AnthropicMessagesRequestOptionalParams instance with only the valid parameters
         """
-        valid_keys = get_type_hints(AnthropicMessagesRequestOptionalParams).keys()
+        valid_keys = _anthropic_messages_optional_param_keys()
         filtered_params = {
             k: v for k, v in params.items() if k in valid_keys and v is not None
         }
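The memoization in this file can be reproduced in isolation. The sketch below uses a hypothetical `OptionalParams` TypedDict in place of `AnthropicMessagesRequestOptionalParams`. The key set of a static TypedDict never changes, so it is resolved once and cached as a `frozenset`.

```python
from functools import lru_cache
from typing import FrozenSet, Optional, TypedDict, get_type_hints


class OptionalParams(TypedDict, total=False):
    """Hypothetical stand-in for the real optional-params TypedDict."""

    max_tokens: int
    temperature: float
    system: Optional[str]


@lru_cache(maxsize=1)
def optional_param_keys() -> FrozenSet[str]:
    # get_type_hints runs only on the first call; later calls hit the cache.
    return frozenset(get_type_hints(OptionalParams).keys())


def filter_params(params: dict) -> dict:
    valid = optional_param_keys()
    return {k: v for k, v in params.items() if k in valid and v is not None}
```

Calling `optional_param_keys.cache_info()` after a few requests shows one miss followed by hits. Returning a `frozenset` rather than the live `dict_keys` view keeps the cached value immutable.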
@@ -1886,7 +1886,9 @@ class BaseLLMHTTPHandler:
         async_httpx_client: AsyncHTTPHandler,
         request_url: str,
         headers: dict,
-        signed_json_body: Optional[bytes],
+        # str when the caller passes a pre-serialized (unsigned) body to avoid
+        # re-dumping; bytes when a provider signed the request (e.g. Bedrock).
+        signed_json_body: Optional[Union[str, bytes]],
         request_body: dict,
         stream: bool,
         logging_obj: LiteLLMLoggingObj,
@@ -2077,8 +2079,18 @@ class BaseLLMHTTPHandler:
             model=model,
         )

+        # The request body was serialized once for the pre-call log input and
+        # again for the wire (json.dumps is O(payload), large for long-context
+        # Claude Code history). Serialize once and reuse for both. Only when
+        # the provider didn't sign the request (sign_request no-op for the
+        # native anthropic path -> signed_json_body is None); signed providers
+        # (e.g. Bedrock) keep their signed body untouched. The HTTP-error
+        # retry path mutates + re-signs the body, so it still re-serializes
+        # internally -- this only deduplicates the success path.
+        request_body_json = json.dumps(request_body)
+
         logging_obj.pre_call(
-            input=[{"role": "user", "content": json.dumps(request_body)}],
+            input=[{"role": "user", "content": request_body_json}],
             api_key="",
             additional_args={
                 "complete_input_dict": request_body,
@@ -2091,7 +2103,9 @@ class BaseLLMHTTPHandler:
                 async_httpx_client=async_httpx_client,
                 request_url=request_url,
                 headers=headers,
-                signed_json_body=signed_json_body,
+                signed_json_body=(
+                    signed_json_body if signed_json_body is not None else request_body_json
+                ),
                 request_body=request_body,
                 stream=stream or False,
                 logging_obj=logging_obj,
@@ -2113,6 +2127,14 @@ class BaseLLMHTTPHandler:
                 litellm_logging_obj=logging_obj,
             )

+            if not self._has_agentic_completion_hook(logging_obj):
+                # No callback overrides async_should_run_agentic_loop, so the
+                # agentic wrapper's only effect would be buffering every chunk
+                # and rebuilding the response from SSE at end-of-stream to call
+                # hooks that all return (False, {}). Stream through directly and
+                # skip that per-chunk + end-of-stream overhead.
+                return completion_stream
+
             from litellm.llms.anthropic.experimental_pass_through.messages.agentic_streaming_iterator import (
                 AgenticAnthropicStreamingIterator,
             )
@@ -4620,6 +4642,51 @@ class BaseLLMHTTPHandler:
         fingerprints = list(kwargs.get("_agentic_loop_fingerprints", []) or [])
         return depth, max(max_loops, 1), fingerprints

+    @staticmethod
+    def _has_agentic_completion_hook(logging_obj: Any) -> bool:
+        """
+        True if any registered callback actually overrides
+        ``async_should_run_agentic_loop`` (the gate every agentic hook goes
+        through). The base ``CustomLogger`` implementation returns
+        ``(False, {})``, so when nothing overrides it the agentic
+        post-processing is a guaranteed no-op and the streaming wrapper that
+        buffers + rebuilds the whole response from SSE just to call it can be
+        skipped entirely.
+
+        Function-identity comparison (not a leaf ``__dict__`` check) so an
+        override inherited through any intermediate class is still detected --
+        a false negative here would silently disable agentic features.
+
+        String entries in ``litellm.callbacks`` (e.g. ``"datadog"``) are
+        resolved to their ``CustomLogger`` instance via
+        ``get_custom_logger_compatible_class`` -- same pattern as
+        ``ProxyLogging._callback_capabilities`` -- so a string-registered
+        agentic callback is detected too.
+        """
+        from litellm.integrations.custom_logger import CustomLogger
+        from litellm.litellm_core_utils.litellm_logging import (
+            get_custom_logger_compatible_class,
+        )
+
+        base_func = CustomLogger.async_should_run_agentic_loop
+        callbacks = litellm.callbacks + (
+            getattr(logging_obj, "dynamic_success_callbacks", None) or []
+        )
+        for cb in callbacks:
+            if isinstance(cb, str):
+                resolved = get_custom_logger_compatible_class(cb)  # type: ignore[arg-type]
+                if resolved is None:
+                    continue
+                cb = resolved
+            if not isinstance(cb, CustomLogger):
+                continue
+            cb_func = getattr(type(cb), "async_should_run_agentic_loop", base_func)
+            if getattr(cb_func, "__func__", cb_func) is not getattr(
+                base_func, "__func__", base_func
+            ):
+                return True
+        return False
+
     @staticmethod
     def _check_agentic_loop_safety(
         tool_calls: Any,
@@ -32,7 +32,7 @@ from litellm.constants import (
     STREAM_SSE_DATA_PREFIX,
 )
 from litellm.integrations.custom_guardrail import CustomGuardrail
-from litellm.litellm_core_utils.dd_tracing import tracer
+from litellm.litellm_core_utils.dd_tracing import NullTracer, tracer
 from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
 from litellm.litellm_core_utils.llm_response_utils.get_headers import (
     get_response_headers,
@@ -65,6 +65,13 @@ else:
     from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request
     from litellm.types.utils import ModelResponse, ModelResponseStream, Usage

+# Datadog streaming spans are a no-op when ddtrace is not enabled, but the
+# ``with tracer.trace(...)`` context manager still allocates a NullSpan and
+# runs __enter__/__exit__ for every streamed chunk. Resolve once at import so
+# the per-chunk hot path can skip the context manager entirely when tracing
+# is off (the default).
+_DD_STREAMING_TRACE_ENABLED = not isinstance(tracer, NullTracer)
+

 def _serialize_http_exception_detail(
     detail: Any,
@@ -231,7 +238,7 @@ def _extract_error_from_sse_chunk(event_line: Union[str, bytes]) -> dict:
         return default_error


-async def create_response(
+async def create_response(  # noqa: PLR0915
     generator: AsyncGenerator[str, None],
     media_type: str,
     headers: dict,
@@ -336,6 +343,13 @@ async def create_response(
         )

     async def combined_generator() -> AsyncGenerator[str, None]:
+        if not _DD_STREAMING_TRACE_ENABLED:
+            # Fast path: no per-chunk span object / context-manager overhead.
+            if first_chunk_value is not None:
+                yield first_chunk_value
+            async for chunk in generator:
+                yield chunk
+            return
         if first_chunk_value is not None:
             with tracer.trace(DD_TRACER_STREAMING_CHUNK_YIELD_RESOURCE):
                 yield first_chunk_value
@@ -1900,6 +1914,23 @@ class ProxyBaseLLMRequestProcessing:
         failure hook and yields via serialize_error. Use for SSE or NDJSON.
         """
         verbose_proxy_logger.debug("inside generator")
+        # Resolve per-stream (not per-chunk) whether the heavy per-chunk path
+        # is needed. When no callback overrides ``async_post_call_streaming_hook``,
+        # no CustomGuardrail is active, and cost injection is disabled, the
+        # per-chunk hook returns the chunk unchanged, ``str_so_far`` is never
+        # consumed, and cost injection is a no-op -- so the per-chunk coroutine
+        # await, response-string materialization, and cost-injection call are
+        # pure overhead on the streaming hot path (the default config).
+        caps = ProxyLogging._callback_capabilities()
+        cost_injection_enabled = bool(
+            getattr(litellm, "include_cost_in_streaming_usage", False)
+        )
+        fast_path = (
+            not caps.has_streaming_chunk_override
+            and not caps.has_guardrail
+            and not cost_injection_enabled
+        )
+        debug_enabled = verbose_proxy_logger.isEnabledFor(logging.DEBUG)
         try:
             str_so_far = ""
             async for (
@@ -1909,9 +1940,17 @@ class ProxyBaseLLMRequestProcessing:
                 response=response,
                 request_data=request_data,
             ):
-                verbose_proxy_logger.debug(
-                    "async_data_generator: received streaming chunk - {}".format(chunk)
-                )
+                # ``.format(chunk)`` was previously evaluated for every chunk
+                # regardless of log level; gate it behind the level check.
+                if debug_enabled:
+                    verbose_proxy_logger.debug(
+                        "async_data_generator: received streaming chunk - %s", chunk
+                    )
+
+                if fast_path:
+                    yield serialize_chunk(chunk)
+                    continue
+
                 chunk = await proxy_logging_obj.async_post_call_streaming_hook(
                     user_api_key_dict=user_api_key_dict,
                     response=chunk,
@@ -1969,7 +2008,7 @@ class ProxyBaseLLMRequestProcessing:
             yield serialize_error(proxy_exception)

     @staticmethod
-    async def async_sse_data_generator(
+    def async_sse_data_generator(
         response: Any,
         user_api_key_dict: UserAPIKeyAuth,
         request_data: dict,
@@ -1977,17 +2016,20 @@ class ProxyBaseLLMRequestProcessing:
     ) -> AsyncGenerator[str, None]:
         """
         Anthropic /messages and Google /generateContent streaming data generator require SSE events.
-        Delegates to async_streaming_data_generator with SSE serializers.
+
+        Returns the underlying ``async_streaming_data_generator`` configured with
+        SSE serializers directly (rather than re-wrapping it in another
+        ``async for: yield`` trampoline), so a streamed chunk traverses one
+        fewer async-generator layer / coroutine resume on the hot path.
         """
-        async for chunk in ProxyBaseLLMRequestProcessing.async_streaming_data_generator(
+        return ProxyBaseLLMRequestProcessing.async_streaming_data_generator(
             response=response,
             user_api_key_dict=user_api_key_dict,
             request_data=request_data,
             proxy_logging_obj=proxy_logging_obj,
             serialize_chunk=ProxyBaseLLMRequestProcessing.return_sse_chunk,
             serialize_error=lambda proxy_exc: f"{STREAM_SSE_DATA_PREFIX}{json.dumps({'error': proxy_exc.to_dict()})}\n\n",
-        ):
-            yield chunk
+        )

     @staticmethod
     def _process_chunk_with_cost_injection(chunk: Any, model_name: str) -> Any:
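The trampoline removal in `async_sse_data_generator` can be illustrated with a generic sketch. The function names here are hypothetical. A plain `def` that returns an async generator object hands the caller the inner generator itself, while an `async def` that re-yields adds a second generator frame that every chunk must pass through.

```python
import asyncio
from typing import AsyncGenerator, List


async def inner(n: int) -> AsyncGenerator[str, None]:
    for i in range(n):
        yield f"chunk-{i}"


async def trampoline(n: int) -> AsyncGenerator[str, None]:
    # Extra async-generator layer: each chunk is resumed twice.
    async for chunk in inner(n):
        yield chunk


def direct(n: int) -> AsyncGenerator[str, None]:
    # Plain function: returns the inner generator object unchanged.
    return inner(n)


async def collect(gen: AsyncGenerator[str, None]) -> List[str]:
    return [c async for c in gen]
```

Both forms produce the same chunks for a caller that iterates with `async for`. One behavioral difference worth noting: with the plain function, argument evaluation happens at call time rather than at first iteration.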
@@ -266,7 +266,174 @@ class AnthropicPassthroughLoggingHandler:
         model: str,
     ) -> Optional[Union[ModelResponse, TextCompletionResponse]]:
         """
-        Builds complete response from raw Anthropic chunks
+        Builds complete response from raw Anthropic chunks.
+
+        Fast path: for the dominant case of a pure-text streaming response
+        (no tool_use / thinking / non-text content blocks), the long run of
+        ``content_block_delta`` text deltas is collapsed into a single
+        equivalent SSE event before conversion. ``chunk_parser`` and
+        ``stream_chunk_builder`` remain the single source of truth for chunk
+        shape, usage math and finish-reason mapping, so the rebuilt response
+        (and therefore the logged/billed payload) is identical -- this is
+        asserted by a parity test. Anything non-trivial falls back to the
+        unchanged legacy reconstruction.
+
+        Per-event Pydantic ``ModelResponseStream`` construction dominated
+        event-loop CPU under concurrent streaming; collapsing the homogeneous
+        text run removes O(num_output_tokens) of it.
+        """
+        collapsed = AnthropicPassthroughLoggingHandler._collapse_pure_text_chunks(
+            all_chunks
+        )
+        if collapsed is not None:
+            return AnthropicPassthroughLoggingHandler._build_complete_streaming_response_legacy(
+                all_chunks=collapsed,
+                litellm_logging_obj=litellm_logging_obj,
+                model=model,
+            )
+        return AnthropicPassthroughLoggingHandler._build_complete_streaming_response_legacy(
+            all_chunks=all_chunks,
+            litellm_logging_obj=litellm_logging_obj,
+            model=model,
+        )
+
+    # Anthropic SSE block/delta types that the fast path is NOT allowed to
+    # collapse -- their presence forces the unchanged legacy path so tool
+    # calls, thinking, citations, etc. keep byte-identical reconstruction.
+    _FAST_PATH_DISALLOWED_DELTA_TYPES = frozenset(
+        {
+            "input_json_delta",
+            "thinking_delta",
+            "signature_delta",
+            "citations_delta",
+        }
+    )
+
+    @staticmethod
+    def _collapse_pure_text_chunks(  # noqa: PLR0915
+        all_chunks: Sequence[Union[str, bytes]],
+    ) -> Optional[List[str]]:
+        """
+        Return a new chunk list with the contiguous run of text-only
+        ``content_block_delta`` events replaced by a single equivalent event,
+        or ``None`` if the stream is not a pure single-text-block response
+        (in which case the caller uses the legacy path unchanged).
+
+        Only ``message_start`` / ``content_block_start(text)`` /
+        ``content_block_delta(text_delta)`` / ``content_block_stop`` /
+        ``message_delta`` / ``message_stop`` / ``ping`` events are accepted.
+        Any other content-block type or delta type returns ``None``.
+        """
+        normalized: List[str] = []
+        for raw in all_chunks:
+            line = raw.decode("utf-8") if isinstance(raw, bytes) else raw
+            for ev in line.split("\n\n"):
+                ev = ev.strip()
+                if ev:
+                    normalized.append(ev)
+
+        text_block_indexes: set = set()
+        out: List[str] = []
+        pending_text: List[str] = []
+        pending_index: Optional[int] = None
+        saw_any_text_delta = False
+
+        def flush() -> None:
+            nonlocal pending_text, pending_index
+            if pending_text:
+                merged = {
+                    "type": "content_block_delta",
+                    "index": pending_index if pending_index is not None else 0,
+                    "delta": {"type": "text_delta", "text": "".join(pending_text)},
+                }
+                out.append("data: " + json.dumps(merged))
+                pending_text = []
+                pending_index = None
+
+        for ev in normalized:
+            idx = ev.find("data:")
+            if idx == -1:
+                # Bare "event: <name>" line. The legacy converter turns this
+                # into an empty ModelResponseStream that contributes nothing
+                # to stream_chunk_builder. Drop the high-frequency interior
+                # markers (content_block_delta / ping); keep every other
+                # bare event line verbatim so chunk ordering and the
+                # load-bearing chunks[0] (event: message_start) are retained.
+                name = ev[len("event:") :].strip() if ev.startswith("event:") else ""
+                if name in ("content_block_delta", "ping"):
+                    continue
+                flush()
+                out.append(ev)
+                continue
+
+            json_str = ev[idx + len("data:") :].strip()
+            try:
+                data = json.loads(json_str)
+            except (json.JSONDecodeError, ValueError):
+                return None
+
+            etype = data.get("type")
+            if etype == "content_block_start":
+                block = data.get("content_block") or {}
+                if block.get("type") != "text":
+                    return None
+                text_block_indexes.add(data.get("index"))
+                flush()
+                out.append(ev)
+            elif etype == "content_block_delta":
+                delta = data.get("delta") or {}
+                dtype = delta.get("type")
+                if (
+                    dtype
+                    in AnthropicPassthroughLoggingHandler._FAST_PATH_DISALLOWED_DELTA_TYPES
+                ):
+                    return None
+                if dtype != "text_delta":
+                    return None
+                cur_index = data.get("index")
+                if cur_index not in text_block_indexes:
+                    return None
+                # Defensive: Anthropic sends blocks strictly sequentially
+                # (start/deltas/stop, then next block), so pending_text from
+                # block N must be flushed by content_block_stop before block
+                # N+1's deltas arrive. If we ever see a delta whose index
+                # disagrees with the current pending buffer, the stream is
+                # interleaved -- fall back to legacy rather than risk merging
+                # text from different blocks under a single index.
+                if (
+                    pending_text
+                    and pending_index is not None
+                    and cur_index != pending_index
+                ):
+                    return None
+                saw_any_text_delta = True
+                pending_index = cur_index
+                pending_text.append(delta.get("text") or "")
+            elif etype == "ping":
+                # Interior no-op; legacy maps it to an empty chunk.
+                continue
+            else:
+                # message_start / content_block_stop / message_delta /
+                # message_stop / error: pass through unchanged.
+                flush()
+                out.append(ev)
+
+        flush()
+
+        if not saw_any_text_delta:
+            return None
+        return out
+
+    @staticmethod
+    def _build_complete_streaming_response_legacy(
+        all_chunks: Sequence[Union[str, bytes]],
+        litellm_logging_obj: LiteLLMLoggingObj,
+        model: str,
+    ) -> Optional[Union[ModelResponse, TextCompletionResponse]]:
+        """
+        Original reconstruction: convert every SSE event to a generic chunk
+        and assemble via stream_chunk_builder. Kept verbatim as the fallback
+        / source of truth for the fast path's parity test.

         - Splits multi-event chunks into individual SSE events
         - Converts str chunks to generic chunks
@@ -1,6 +1,6 @@
 import asyncio
 from datetime import datetime
-from typing import List, Optional
+from typing import List, Optional, Tuple

 import httpx

@@ -45,28 +45,44 @@ class PassThroughStreamingHandler:
             litellm_logging_obj=litellm_logging_obj,
         )

+        # Resolve once per stream rather than re-reading the global +
+        # re-branching on every chunk. ``include_cost_in_streaming_usage`` is
+        # set at config load and stable for the process, matching how the
+        # proxy-level streaming fast path resolves it.
+        cost_injection_active = (
+            bool(getattr(litellm, "include_cost_in_streaming_usage", False))
+            and bool(model_name)
+            and endpoint_type in (EndpointType.VERTEX_AI, EndpointType.ANTHROPIC)
+        )
         try:
-            async for chunk in response.aiter_bytes():
-                raw_bytes.append(chunk)
-                if (
-                    getattr(litellm, "include_cost_in_streaming_usage", False)
-                    and model_name
-                ):
+            if not cost_injection_active:
+                # Hot path: just buffer for end-of-stream logging and forward.
+                async for chunk in response.aiter_bytes():
+                    raw_bytes.append(chunk)
+                    yield chunk
+            else:
+                # ``cost_injection_active`` already requires ``model_name`` to
+                # be truthy; pin to a typed local so mypy narrows ``Optional[str]``
+                # -> ``str`` for the per-chunk call site.
+                assert model_name is not None
+                resolved_model_name: str = model_name
+                async for chunk in response.aiter_bytes():
+                    raw_bytes.append(chunk)
                     if endpoint_type == EndpointType.VERTEX_AI:
                         if "streamRawPredict" in url_route or "rawPredict" in url_route:
                             modified_chunk = ProxyBaseLLMRequestProcessing._process_chunk_with_cost_injection(
-                                chunk, model_name
+                                chunk, resolved_model_name
                             )
                             if modified_chunk is not None:
                                 chunk = modified_chunk
-                    elif endpoint_type == EndpointType.ANTHROPIC:
+                    else:  # EndpointType.ANTHROPIC
                         modified_chunk = ProxyBaseLLMRequestProcessing._process_chunk_with_cost_injection(
-                            chunk, model_name
+                            chunk, resolved_model_name
                         )
                         if modified_chunk is not None:
                             chunk = modified_chunk

-                yield chunk
+                    yield chunk
         except Exception as e:
             verbose_proxy_logger.error(f"Error in chunk_processor: {str(e)}")
             raise
@@ -115,64 +131,20 @@ class PassThroughStreamingHandler:
         - OpenAI
         """
         try:
-            all_chunks = PassThroughStreamingHandler._convert_raw_bytes_to_str_lines(
-                raw_bytes
+            (
+                standard_logging_response_object,
+                kwargs,
+            ) = PassThroughStreamingHandler._build_passthrough_logging_result(
+                litellm_logging_obj=litellm_logging_obj,
+                passthrough_success_handler_obj=passthrough_success_handler_obj,
+                url_route=url_route,
+                request_body=request_body,
+                endpoint_type=endpoint_type,
+                start_time=start_time,
+                raw_bytes=raw_bytes,
+                end_time=end_time,
+                model=model,
             )
-            standard_logging_response_object: Optional[
-                PassThroughEndpointLoggingResultValues
-            ] = None
-            kwargs: dict = {}
-            if endpoint_type == EndpointType.ANTHROPIC:
-                anthropic_passthrough_logging_handler_result = AnthropicPassthroughLoggingHandler._handle_logging_anthropic_collected_chunks(
-                    litellm_logging_obj=litellm_logging_obj,
-                    passthrough_success_handler_obj=passthrough_success_handler_obj,
-                    url_route=url_route,
-                    request_body=request_body,
-                    endpoint_type=endpoint_type,
-                    start_time=start_time,
-                    all_chunks=all_chunks,
-                    end_time=end_time,
-                )
-                standard_logging_response_object = (
-                    anthropic_passthrough_logging_handler_result["result"]
-                )
-                kwargs = anthropic_passthrough_logging_handler_result["kwargs"]
-            elif endpoint_type == EndpointType.VERTEX_AI:
-                vertex_passthrough_logging_handler_result = VertexPassthroughLoggingHandler._handle_logging_vertex_collected_chunks(
-                    litellm_logging_obj=litellm_logging_obj,
-                    passthrough_success_handler_obj=passthrough_success_handler_obj,
-                    url_route=url_route,
-                    request_body=request_body,
-                    endpoint_type=endpoint_type,
-                    start_time=start_time,
-                    all_chunks=all_chunks,
-                    end_time=end_time,
-                    model=model,
-                )
-                standard_logging_response_object = (
-                    vertex_passthrough_logging_handler_result["result"]
-                )
-                kwargs = vertex_passthrough_logging_handler_result["kwargs"]
-            elif endpoint_type == EndpointType.OPENAI:
-                openai_passthrough_logging_handler_result = OpenAIPassthroughLoggingHandler._handle_logging_openai_collected_chunks(
-                    litellm_logging_obj=litellm_logging_obj,
-                    passthrough_success_handler_obj=passthrough_success_handler_obj,
-                    url_route=url_route,
-                    request_body=request_body,
-                    endpoint_type=endpoint_type,
-                    start_time=start_time,
-                    all_chunks=all_chunks,
-                    end_time=end_time,
-                )
-                standard_logging_response_object = (
-                    openai_passthrough_logging_handler_result["result"]
-                )
-                kwargs = openai_passthrough_logging_handler_result["kwargs"]
-
-            if standard_logging_response_object is None:
-                standard_logging_response_object = StandardPassThroughResponseObject(
-                    response=f"cannot parse chunks to standard response object. Chunks={all_chunks}"
-                )
             await litellm_logging_obj.async_success_handler(
                 result=standard_logging_response_object,
                 start_time=start_time,
@@ -199,6 +171,89 @@ class PassThroughStreamingHandler:
                 f"Error in _route_streaming_logging_to_handler: {str(e)}"
             )

+    @staticmethod
+    def _build_passthrough_logging_result(
+        litellm_logging_obj: LiteLLMLoggingObj,
+        passthrough_success_handler_obj: PassThroughEndpointLogging,
+        url_route: str,
+        request_body: dict,
+        endpoint_type: EndpointType,
+        start_time: datetime,
+        raw_bytes: List[bytes],
+        end_time: datetime,
+        model: Optional[str],
+    ) -> Tuple[PassThroughEndpointLoggingResultValues, dict]:
+        """
+        Synchronous, CPU-bound reconstruction of the standard logging payload
+        from collected raw SSE bytes. Extracted from
+        _route_streaming_logging_to_handler so the per-endpoint dispatch can
+        be unit-tested in isolation. Still invoked synchronously on the event
+        loop; an off-loop dispatch is a future change, not part of this PR.
+        """
+        all_chunks = PassThroughStreamingHandler._convert_raw_bytes_to_str_lines(
+            raw_bytes
+        )
+        standard_logging_response_object: Optional[
+            PassThroughEndpointLoggingResultValues
+        ] = None
+        kwargs: dict = {}
+        if endpoint_type == EndpointType.ANTHROPIC:
+            anthropic_passthrough_logging_handler_result = AnthropicPassthroughLoggingHandler._handle_logging_anthropic_collected_chunks(
+                litellm_logging_obj=litellm_logging_obj,
+                passthrough_success_handler_obj=passthrough_success_handler_obj,
+                url_route=url_route,
+                request_body=request_body,
+                endpoint_type=endpoint_type,
+                start_time=start_time,
+                all_chunks=all_chunks,
+                end_time=end_time,
+            )
+            standard_logging_response_object = (
+                anthropic_passthrough_logging_handler_result["result"]
+            )
+            kwargs = anthropic_passthrough_logging_handler_result["kwargs"]
+        elif endpoint_type == EndpointType.VERTEX_AI:
+            vertex_passthrough_logging_handler_result = (
+                VertexPassthroughLoggingHandler._handle_logging_vertex_collected_chunks(
+                    litellm_logging_obj=litellm_logging_obj,
+                    passthrough_success_handler_obj=passthrough_success_handler_obj,
+                    url_route=url_route,
+                    request_body=request_body,
+                    endpoint_type=endpoint_type,
+                    start_time=start_time,
+                    all_chunks=all_chunks,
+                    end_time=end_time,
+                    model=model,
+                )
+            )
+            standard_logging_response_object = (
+                vertex_passthrough_logging_handler_result["result"]
|
||||
)
|
||||
kwargs = vertex_passthrough_logging_handler_result["kwargs"]
|
||||
elif endpoint_type == EndpointType.OPENAI:
|
||||
openai_passthrough_logging_handler_result = (
|
||||
OpenAIPassthroughLoggingHandler._handle_logging_openai_collected_chunks(
|
||||
litellm_logging_obj=litellm_logging_obj,
|
||||
passthrough_success_handler_obj=passthrough_success_handler_obj,
|
||||
url_route=url_route,
|
||||
request_body=request_body,
|
||||
endpoint_type=endpoint_type,
|
||||
start_time=start_time,
|
||||
all_chunks=all_chunks,
|
||||
end_time=end_time,
|
||||
)
|
||||
)
|
||||
standard_logging_response_object = (
|
||||
openai_passthrough_logging_handler_result["result"]
|
||||
)
|
||||
kwargs = openai_passthrough_logging_handler_result["kwargs"]
|
||||
|
||||
if standard_logging_response_object is None:
|
||||
standard_logging_response_object = StandardPassThroughResponseObject(
|
||||
response=f"cannot parse chunks to standard response object. Chunks={all_chunks}"
|
||||
)
|
||||
return standard_logging_response_object, kwargs
|
||||
|
||||
@staticmethod
|
||||
def _extract_model_for_cost_injection(
|
||||
request_body: Optional[dict],
|
||||
|
||||
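The docstring of `_build_passthrough_logging_result` above notes that the method still runs synchronously on the event loop and that off-loop dispatch is a future change. As a rough sketch of what such a dispatch could look like (this is not part of the PR; `build_logging_payload` and `build_off_loop` are hypothetical names, and the stand-in function only joins bytes), a synchronous function can be handed to a worker thread with `asyncio.to_thread`:

```python
import asyncio


def build_logging_payload(raw_bytes: list[bytes]) -> str:
    """Hypothetical stand-in for a synchronous reconstruction step."""
    return b"".join(raw_bytes).decode("utf-8", errors="replace")


async def build_off_loop(raw_bytes: list[bytes]) -> str:
    # Run the synchronous function in a worker thread so the event loop
    # stays free to serve other coroutines while it executes.
    return await asyncio.to_thread(build_logging_payload, raw_bytes)


if __name__ == "__main__":
    print(asyncio.run(build_off_loop([b"event: ping\n", b"data: {}\n\n"])))
```

Because of the GIL, a thread mainly helps the loop stay responsive; it does not make pure-Python CPU work run in parallel.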
+16
-1
@@ -1596,7 +1596,22 @@ class ProxyLogging:
|
||||
     iterator_overrides.append((resolved, "override"))
 elif "apply_guardrail" in cls_attrs:
     iterator_overrides.append((resolved, "apply_guardrail"))
-if "async_post_call_streaming_hook" in cls_attrs:
+# Walk the MRO for ``async_post_call_streaming_hook`` rather than
+# using the leaf-class ``__dict__`` check used by the other flags:
+# before this PR the hook was unconditionally invoked, so a
+# callback that inherits an override from an intermediate parent
+# (e.g. a vendor base class providing the override, with the
+# registered class adding nothing else) MUST still be detected.
+# A leaf-class miss here would silently drop the inherited hook.
+base_streaming_hook = CustomLogger.async_post_call_streaming_hook
+cls_streaming_hook = getattr(
+    cls,
+    "async_post_call_streaming_hook",
+    base_streaming_hook,
+)
+if getattr(
+    cls_streaming_hook, "__func__", cls_streaming_hook
+) is not getattr(base_streaming_hook, "__func__", base_streaming_hook):
     has_streaming_chunk_override = True
 if "async_pre_call_hook" in cls_attrs:
     has_pre_call_override = True
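The difference between the leaf-class `__dict__` check and the function-identity check described in the comment above can be illustrated with a small, self-contained sketch. The class and function names here (`BaseHook`, `VendorBase`, `Registered`, `on_chunk`) are hypothetical and are not LiteLLM classes; only the comparison technique mirrors the hunk.

```python
class BaseHook:
    async def on_chunk(self, chunk):
        # Default no-op hook.
        return chunk


class VendorBase(BaseHook):
    async def on_chunk(self, chunk):
        # The real override lives in an intermediate parent.
        return chunk.upper()


class Registered(VendorBase):
    # The registered class adds nothing of its own.
    pass


class Plain(BaseHook):
    pass


def overrides_in_leaf_dict(cls) -> bool:
    # Only sees attributes defined directly on ``cls``.
    return "on_chunk" in vars(cls)


def overrides_anywhere_in_mro(cls) -> bool:
    # Resolve through the MRO and compare function identity with the base.
    base = BaseHook.on_chunk
    found = getattr(cls, "on_chunk", base)
    return getattr(found, "__func__", found) is not getattr(base, "__func__", base)
```

With these definitions, the leaf check misses the override that `Registered` inherits from `VendorBase`, while the identity check detects it and still reports no override for `Plain`.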
|
||||
|
||||
@@ -0,0 +1,624 @@
|
||||
#!/usr/bin/env python3
"""Benchmark LiteLLM proxy /v1/messages (Anthropic Messages API) streaming.

Measures the two metrics that matter for an interactive streaming proxy:

  * TTFT - time to first streamed token (first ``content_block_delta``)
  * TPM  - sustained output-token throughput, reported in tokens / second
           (not per minute) once the full stream is consumed, plus request
           throughput (RPS)

It boots a local mock Anthropic provider that speaks the real Anthropic
streaming SSE wire format (``message_start`` -> ``content_block_delta`` ->
``message_stop``) and a LiteLLM proxy from any checkout, so commits/branches
can be compared without depending on real provider latency.

Example:
    uv run python scripts/benchmark_anthropic_messages_perf.py \
        --label baseline --proxy-command ".venv/bin/litellm"

Compare an already-running proxy:
    uv run python scripts/benchmark_anthropic_messages_perf.py \
        --no-start-proxy --label current
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import shlex
|
||||
import signal
|
||||
import statistics
|
||||
import subprocess
|
||||
import tempfile
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
|
||||
import aiohttp
|
||||
from aiohttp import web
|
||||
|
||||
DEFAULT_MODEL = "claude-perf-test"
|
||||
DEFAULT_API_KEY = "sk-1234"
|
||||
|
||||
|
||||
@dataclass
|
||||
class StreamSample:
|
||||
success: bool
|
||||
ttft_ms: float
|
||||
total_ms: float
|
||||
output_tokens: int
|
||||
status_code: int
|
||||
error: str = ""
|
||||
|
||||
|
||||
@dataclass
|
||||
class SummaryStats:
|
||||
requests: int
|
||||
failures: int
|
||||
rps: float
|
||||
ttft_mean_ms: float
|
||||
ttft_p50_ms: float
|
||||
ttft_p95_ms: float
|
||||
ttft_p99_ms: float
|
||||
total_p50_ms: float
|
||||
total_p95_ms: float
|
||||
tokens_per_sec: float
|
||||
|
||||
|
||||
class MockAnthropicProvider:
|
||||
"""Minimal Anthropic Messages API server (real streaming SSE format)."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
host: str,
|
||||
port: int,
|
||||
first_token_delay_ms: float,
|
||||
stream_content_chunks: int,
|
||||
) -> None:
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.first_token_delay_ms = first_token_delay_ms
|
||||
self.stream_content_chunks = stream_content_chunks
|
||||
self.runner: Optional[web.AppRunner] = None
|
||||
|
||||
@property
|
||||
def base_url(self) -> str:
|
||||
return f"http://{self.host}:{self.port}"
|
||||
|
||||
async def start(self) -> None:
|
||||
app = web.Application()
|
||||
app.router.add_post("/v1/messages", self.handle_messages)
|
||||
self.runner = web.AppRunner(app, access_log=None)
|
||||
await self.runner.setup()
|
||||
site = web.TCPSite(self.runner, self.host, self.port)
|
||||
await site.start()
|
||||
|
||||
async def stop(self) -> None:
|
||||
if self.runner is not None:
|
||||
await self.runner.cleanup()
|
||||
|
||||
async def handle_messages(self, request: web.Request) -> web.StreamResponse:
|
||||
body = await request.json()
|
||||
if body.get("stream"):
|
||||
return await self._streaming_response(request, body)
|
||||
return self._json_response(body)
|
||||
|
||||
def _json_response(self, body: dict[str, Any]) -> web.Response:
|
||||
payload = {
|
||||
"id": "msg_perf",
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"model": body.get("model", DEFAULT_MODEL),
|
||||
"content": [{"type": "text", "text": "hello"}],
|
||||
"stop_reason": "end_turn",
|
||||
"stop_sequence": None,
|
||||
"usage": {"input_tokens": 8, "output_tokens": 1},
|
||||
}
|
||||
return web.json_response(payload)
|
||||
|
||||
@staticmethod
|
||||
def _sse(event: str, data: dict[str, Any]) -> bytes:
|
||||
return f"event: {event}\ndata: {json.dumps(data)}\n\n".encode()
|
||||
|
||||
async def _streaming_response(
|
||||
self, request: web.Request, body: dict[str, Any]
|
||||
) -> web.StreamResponse:
|
||||
model = body.get("model", DEFAULT_MODEL)
|
||||
response = web.StreamResponse(
|
||||
status=200,
|
||||
headers={
|
||||
"Content-Type": "text/event-stream",
|
||||
"Cache-Control": "no-cache",
|
||||
},
|
||||
)
|
||||
await response.prepare(request)
|
||||
|
||||
await response.write(
|
||||
self._sse(
|
||||
"message_start",
|
||||
{
|
||||
"type": "message_start",
|
||||
"message": {
|
||||
"id": "msg_perf",
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"model": model,
|
||||
"content": [],
|
||||
"stop_reason": None,
|
||||
"stop_sequence": None,
|
||||
"usage": {"input_tokens": 8, "output_tokens": 0},
|
||||
},
|
||||
},
|
||||
)
|
||||
)
|
||||
await response.write(
|
||||
self._sse(
|
||||
"content_block_start",
|
||||
{
|
||||
"type": "content_block_start",
|
||||
"index": 0,
|
||||
"content_block": {"type": "text", "text": ""},
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
if self.first_token_delay_ms > 0:
|
||||
await asyncio.sleep(self.first_token_delay_ms / 1000)
|
||||
|
||||
for _ in range(self.stream_content_chunks):
|
||||
await response.write(
|
||||
self._sse(
|
||||
"content_block_delta",
|
||||
{
|
||||
"type": "content_block_delta",
|
||||
"index": 0,
|
||||
"delta": {"type": "text_delta", "text": "hello "},
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
await response.write(
|
||||
self._sse("content_block_stop", {"type": "content_block_stop", "index": 0})
|
||||
)
|
||||
await response.write(
|
||||
self._sse(
|
||||
"message_delta",
|
||||
{
|
||||
"type": "message_delta",
|
||||
"delta": {"stop_reason": "end_turn", "stop_sequence": None},
|
||||
"usage": {"output_tokens": self.stream_content_chunks},
|
||||
},
|
||||
)
|
||||
)
|
||||
await response.write(self._sse("message_stop", {"type": "message_stop"}))
|
||||
await response.write_eof()
|
||||
return response
|
||||
|
||||
|
||||
def percentile(values: list[float], pct: float) -> float:
    if not values:
        return 0.0
    sorted_values = sorted(values)
    index = min(int(len(sorted_values) * pct / 100), len(sorted_values) - 1)
    return sorted_values[index]


def summarize(samples: list[StreamSample], wall_time_s: float) -> SummaryStats:
    ok = [s for s in samples if s.success]
    ttfts = [s.ttft_ms for s in ok]
    totals = [s.total_ms for s in ok]
    total_tokens = sum(s.output_tokens for s in ok)
    return SummaryStats(
        requests=len(samples),
        failures=len(samples) - len(ok),
        rps=(len(ok) / wall_time_s) if wall_time_s > 0 else 0.0,
        ttft_mean_ms=statistics.mean(ttfts) if ttfts else 0.0,
        ttft_p50_ms=percentile(ttfts, 50),
        ttft_p95_ms=percentile(ttfts, 95),
        ttft_p99_ms=percentile(ttfts, 99),
        total_p50_ms=percentile(totals, 50),
        total_p95_ms=percentile(totals, 95),
        # Aggregate output-token throughput: total tokens delivered across all
        # successful requests divided by wall-clock time. This is the true
        # server TPM and (unlike tokens / summed-per-request-latency) scales
        # correctly with concurrency.
        tokens_per_sec=(total_tokens / wall_time_s) if wall_time_s > 0 else 0.0,
    )
|
||||
|
||||
|
||||
def get_git_revision(litellm_dir: Path) -> str:
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["git", "rev-parse", "--short", "HEAD"],
|
||||
cwd=litellm_dir,
|
||||
check=True,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
)
|
||||
return result.stdout.strip()
|
||||
except Exception:
|
||||
return "unknown"
|
||||
|
||||
|
||||
def write_proxy_config(config_path: Path, provider_base_url: str, api_key: str) -> None:
|
||||
config_path.write_text(
|
||||
f"""model_list:
|
||||
- model_name: {DEFAULT_MODEL}
|
||||
litellm_params:
|
||||
model: anthropic/{DEFAULT_MODEL}
|
||||
api_key: fake-provider-key
|
||||
api_base: {provider_base_url}
|
||||
|
||||
general_settings:
|
||||
master_key: {api_key}
|
||||
|
||||
litellm_settings:
|
||||
telemetry: false
|
||||
""",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
|
||||
async def wait_for_proxy(base_url: str, timeout_s: float) -> None:
|
||||
deadline = time.perf_counter() + timeout_s
|
||||
last_error = ""
|
||||
async with aiohttp.ClientSession() as session:
|
||||
while time.perf_counter() < deadline:
|
||||
try:
|
||||
async with session.get(f"{base_url}/health/liveliness") as response:
|
||||
if response.status < 500:
|
||||
return
|
||||
last_error = f"HTTP {response.status}"
|
||||
except Exception as exc:
|
||||
last_error = str(exc)
|
||||
await asyncio.sleep(0.5)
|
||||
raise TimeoutError(f"Timed out waiting for proxy at {base_url}: {last_error}")
|
||||
|
||||
|
||||
def start_proxy_process(
|
||||
litellm_dir: Path,
|
||||
proxy_command: str,
|
||||
config_path: Path,
|
||||
port: int,
|
||||
log_path: Path,
|
||||
) -> subprocess.Popen:
|
||||
command = shlex.split(proxy_command) + [
|
||||
"--config",
|
||||
str(config_path),
|
||||
"--port",
|
||||
str(port),
|
||||
]
|
||||
env = {
|
||||
**os.environ,
|
||||
"LITELLM_TELEMETRY": "False",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
}
|
||||
log_file = log_path.open("w", encoding="utf-8")
|
||||
return subprocess.Popen(
|
||||
command,
|
||||
cwd=litellm_dir,
|
||||
env=env,
|
||||
stdout=log_file,
|
||||
stderr=subprocess.STDOUT,
|
||||
start_new_session=True,
|
||||
)
|
||||
|
||||
|
||||
def stop_proxy_process(process: subprocess.Popen) -> None:
|
||||
if process.poll() is not None:
|
||||
return
|
||||
try:
|
||||
os.killpg(process.pid, signal.SIGTERM)
|
||||
process.wait(timeout=10)
|
||||
except Exception:
|
||||
try:
|
||||
os.killpg(process.pid, signal.SIGKILL)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
async def measure_stream(
|
||||
session: aiohttp.ClientSession,
|
||||
url: str,
|
||||
headers: dict[str, str],
|
||||
payload: dict[str, Any],
|
||||
) -> StreamSample:
|
||||
start = time.perf_counter()
|
||||
ttft_ms = 0.0
|
||||
output_tokens = 0
|
||||
try:
|
||||
async with session.post(url, headers=headers, json=payload) as response:
|
||||
if response.status != 200:
|
||||
body = await response.read()
|
||||
return StreamSample(
|
||||
success=False,
|
||||
ttft_ms=0.0,
|
||||
total_ms=(time.perf_counter() - start) * 1000,
|
||||
output_tokens=0,
|
||||
status_code=response.status,
|
||||
error=body.decode("utf-8", errors="ignore")[:200],
|
||||
)
|
||||
async for raw_line in response.content:
|
||||
line = raw_line.strip()
|
||||
if not line.startswith(b"data:"):
|
||||
continue
|
||||
data = line[5:].strip()
|
||||
if data == b"[DONE]":
|
||||
break
|
||||
try:
|
||||
event = json.loads(data)
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
etype = event.get("type")
|
||||
if etype == "content_block_delta":
|
||||
if ttft_ms == 0.0:
|
||||
ttft_ms = (time.perf_counter() - start) * 1000
|
||||
output_tokens += 1
|
||||
elif etype == "message_stop":
|
||||
break
|
||||
total_ms = (time.perf_counter() - start) * 1000
|
||||
if ttft_ms == 0.0:
|
||||
return StreamSample(
|
||||
success=False,
|
||||
ttft_ms=0.0,
|
||||
total_ms=total_ms,
|
||||
output_tokens=0,
|
||||
status_code=response.status,
|
||||
error="stream ended before a content token",
|
||||
)
|
||||
return StreamSample(
|
||||
success=True,
|
||||
ttft_ms=ttft_ms,
|
||||
total_ms=total_ms,
|
||||
output_tokens=output_tokens,
|
||||
status_code=response.status,
|
||||
)
|
||||
except Exception as exc:
|
||||
return StreamSample(
|
||||
success=False,
|
||||
ttft_ms=0.0,
|
||||
total_ms=(time.perf_counter() - start) * 1000,
|
||||
output_tokens=0,
|
||||
status_code=0,
|
||||
error=str(exc)[:200],
|
||||
)
|
||||
|
||||
|
||||
async def run_benchmark(
|
||||
url: str,
|
||||
headers: dict[str, str],
|
||||
payload: dict[str, Any],
|
||||
requests: int,
|
||||
concurrency: int,
|
||||
warmup: int,
|
||||
timeout_s: float,
|
||||
) -> SummaryStats:
|
||||
timeout = aiohttp.ClientTimeout(total=timeout_s)
|
||||
connector = aiohttp.TCPConnector(
|
||||
limit=max(concurrency * 2, 10),
|
||||
limit_per_host=max(concurrency, 10),
|
||||
force_close=False,
|
||||
)
|
||||
|
||||
async def worker(
|
||||
session: aiohttp.ClientSession,
|
||||
counter: list[int],
|
||||
budget: int,
|
||||
sink: list[StreamSample],
|
||||
) -> None:
|
||||
# Steady-state load: exactly `concurrency` workers, each pulling the
|
||||
# next request slot as soon as its previous one finishes. Keeps
|
||||
# in-flight concurrency constant (vs. a gather-all + semaphore burst)
|
||||
# which removes the thundering-herd variance that otherwise swamps a
|
||||
# 10% signal.
|
||||
while True:
|
||||
idx = counter[0]
|
||||
if idx >= budget:
|
||||
return
|
||||
counter[0] = idx + 1
|
||||
sink.append(await measure_stream(session, url, headers, payload))
|
||||
|
||||
async with aiohttp.ClientSession(connector=connector, timeout=timeout) as session:
|
||||
if warmup > 0:
|
||||
wcounter = [0]
|
||||
await asyncio.gather(
|
||||
*[worker(session, wcounter, warmup, []) for _ in range(concurrency)]
|
||||
)
|
||||
samples: list[StreamSample] = []
|
||||
counter = [0]
|
||||
wall_start = time.perf_counter()
|
||||
await asyncio.gather(
|
||||
*[worker(session, counter, requests, samples) for _ in range(concurrency)]
|
||||
)
|
||||
wall_time_s = time.perf_counter() - wall_start
|
||||
return summarize(samples, wall_time_s)
|
||||
|
||||
|
||||
def stats_to_dict(stats: SummaryStats) -> dict[str, Any]:
|
||||
return {
|
||||
"requests": stats.requests,
|
||||
"failures": stats.failures,
|
||||
"rps": stats.rps,
|
||||
"ttft_mean_ms": stats.ttft_mean_ms,
|
||||
"ttft_p50_ms": stats.ttft_p50_ms,
|
||||
"ttft_p95_ms": stats.ttft_p95_ms,
|
||||
"ttft_p99_ms": stats.ttft_p99_ms,
|
||||
"total_p50_ms": stats.total_p50_ms,
|
||||
"total_p95_ms": stats.total_p95_ms,
|
||||
"tokens_per_sec": stats.tokens_per_sec,
|
||||
}
|
||||
|
||||
|
||||
def print_summary(label: str, revision: str, stats: SummaryStats) -> None:
|
||||
print("\n=== Anthropic /v1/messages streaming benchmark ===")
|
||||
print(f"Label: {label}")
|
||||
print(f"Revision: {revision}")
|
||||
print(f"Requests: {stats.requests} Failures: {stats.failures}")
|
||||
print(f"TTFT mean: {stats.ttft_mean_ms:.2f} ms")
|
||||
print(f"TTFT p50: {stats.ttft_p50_ms:.2f} ms")
|
||||
print(f"TTFT p95: {stats.ttft_p95_ms:.2f} ms")
|
||||
print(f"TTFT p99: {stats.ttft_p99_ms:.2f} ms")
|
||||
print(f"Full p50: {stats.total_p50_ms:.2f} ms")
|
||||
print(f"Full p95: {stats.total_p95_ms:.2f} ms")
|
||||
print(f"Throughput: {stats.rps:.2f} req/s")
|
||||
print(f"TPM: {stats.tokens_per_sec:.1f} output tokens/s")
|
||||
print("\nMarkdown row:")
|
||||
print(
|
||||
"| "
|
||||
+ " | ".join(
|
||||
[
|
||||
label,
|
||||
revision,
|
||||
f"{stats.ttft_p50_ms:.2f}",
|
||||
f"{stats.ttft_p95_ms:.2f}",
|
||||
f"{stats.tokens_per_sec:.1f}",
|
||||
f"{stats.rps:.2f}",
|
||||
]
|
||||
)
|
||||
+ " |"
|
||||
)
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
parser.add_argument("--label", default="current")
|
||||
parser.add_argument("--litellm-dir", default=str(Path.cwd()))
|
||||
parser.add_argument("--proxy-command", default="uv run litellm")
|
||||
parser.add_argument("--proxy-host", default="127.0.0.1")
|
||||
parser.add_argument("--proxy-port", type=int, default=4000)
|
||||
parser.add_argument("--provider-host", default="127.0.0.1")
|
||||
parser.add_argument("--provider-port", type=int, default=8098)
|
||||
parser.add_argument("--api-key", default=DEFAULT_API_KEY)
|
||||
parser.add_argument("--requests", type=int, default=300)
|
||||
parser.add_argument("--concurrency", type=int, default=20)
|
||||
parser.add_argument("--warmup", type=int, default=30)
|
||||
parser.add_argument("--timeout", type=float, default=30)
|
||||
parser.add_argument("--proxy-start-timeout", type=float, default=90)
|
||||
parser.add_argument("--provider-first-token-delay-ms", type=float, default=0)
|
||||
parser.add_argument(
|
||||
"--provider-stream-content-chunks",
|
||||
type=int,
|
||||
default=64,
|
||||
help="Number of text delta chunks the mock emits (default 64).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--repeats",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Run the suite N times against the same proxy; report the median run.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-start-proxy",
|
||||
action="store_true",
|
||||
help="Benchmark an already-running proxy at --proxy-host/--proxy-port",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--provider-url",
|
||||
help="Use an already-running Anthropic-compatible provider",
|
||||
)
|
||||
parser.add_argument("--output-json", help="Write machine-readable results")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
async def async_main() -> None:
|
||||
args = parse_args()
|
||||
litellm_dir = Path(args.litellm_dir).resolve()
|
||||
revision = get_git_revision(litellm_dir)
|
||||
proxy_base_url = f"http://{args.proxy_host}:{args.proxy_port}"
|
||||
proxy_url = f"{proxy_base_url}/v1/messages"
|
||||
headers = {
|
||||
"Authorization": f"Bearer {args.api_key}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
stream_payload = {
|
||||
"model": DEFAULT_MODEL,
|
||||
"max_tokens": 256,
|
||||
"messages": [{"role": "user", "content": "hi"}],
|
||||
"stream": True,
|
||||
}
|
||||
|
||||
provider: Optional[MockAnthropicProvider] = None
|
||||
proxy_process: Optional[subprocess.Popen] = None
|
||||
with tempfile.TemporaryDirectory(prefix="litellm-anthropic-perf-") as tmp_dir_name:
|
||||
tmp_dir = Path(tmp_dir_name)
|
||||
proxy_log_path = tmp_dir / "proxy.log"
|
||||
if args.provider_url:
|
||||
provider_base_url = args.provider_url.rstrip("/")
|
||||
else:
|
||||
provider = MockAnthropicProvider(
|
||||
host=args.provider_host,
|
||||
port=args.provider_port,
|
||||
first_token_delay_ms=args.provider_first_token_delay_ms,
|
||||
stream_content_chunks=args.provider_stream_content_chunks,
|
||||
)
|
||||
await provider.start()
|
||||
provider_base_url = provider.base_url
|
||||
|
||||
config_path = tmp_dir / "config.yaml"
|
||||
write_proxy_config(config_path, provider_base_url, args.api_key)
|
||||
|
||||
try:
|
||||
if not args.no_start_proxy:
|
||||
proxy_process = start_proxy_process(
|
||||
litellm_dir=litellm_dir,
|
||||
proxy_command=args.proxy_command,
|
||||
config_path=config_path,
|
||||
port=args.proxy_port,
|
||||
log_path=proxy_log_path,
|
||||
)
|
||||
await wait_for_proxy(proxy_base_url, args.proxy_start_timeout)
|
||||
|
||||
runs: list[SummaryStats] = []
|
||||
for run_idx in range(max(1, args.repeats)):
|
||||
if args.repeats > 1:
|
||||
print(f"\n--- Run {run_idx + 1}/{args.repeats} ---")
|
||||
stats = await run_benchmark(
|
||||
url=proxy_url,
|
||||
headers=headers,
|
||||
payload=stream_payload,
|
||||
requests=args.requests,
|
||||
concurrency=args.concurrency,
|
||||
warmup=args.warmup,
|
||||
timeout_s=args.timeout,
|
||||
)
|
||||
runs.append(stats)
|
||||
if args.repeats > 1:
|
||||
print(
|
||||
f" run {run_idx + 1}: TTFT p50={stats.ttft_p50_ms:.2f}ms "
|
||||
f"TPM={stats.tokens_per_sec:.1f} tok/s RPS={stats.rps:.2f}"
|
||||
)
|
||||
|
||||
stats = sorted(runs, key=lambda s: s.ttft_p50_ms)[len(runs) // 2]
|
||||
finally:
|
||||
if proxy_process is not None:
|
||||
stop_proxy_process(proxy_process)
|
||||
if provider is not None:
|
||||
await provider.stop()
|
||||
|
||||
print_summary(args.label, revision, stats)
|
||||
|
||||
if args.output_json:
|
||||
Path(args.output_json).write_text(
|
||||
json.dumps(
|
||||
{
|
||||
"label": args.label,
|
||||
"revision": revision,
|
||||
"proxy_streaming": stats_to_dict(stats),
|
||||
"proxy_log_path": str(proxy_log_path),
|
||||
},
|
||||
indent=2,
|
||||
sort_keys=True,
|
||||
),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
asyncio.run(async_main())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
+129
@@ -544,3 +544,132 @@ class TestThinkingSummaryPreservation:
|
||||
assert result == {
|
||||
"reasoning_effort": {"effort": "medium", "summary": "concise"}
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Parity tests: redundant empty-text-block sanitization scan removal.
|
||||
# The async wrapper sanitizes once and tells the handler to skip its second
|
||||
# (redundant) full-messages scan; the sync entry point still sanitizes.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _empty_block_msgs():
|
||||
return [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "text", "text": " "}, # whitespace-only -> stripped
|
||||
{"type": "tool_use", "id": "t", "name": "B", "input": {}},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
def test_handler_strips_when_no_presanitized_flag():
|
||||
"""Sync entry point (no async wrapper): handler must still sanitize."""
|
||||
from litellm.llms.anthropic.experimental_pass_through.messages import handler
|
||||
|
||||
with patch.object(
|
||||
handler,
|
||||
"strip_empty_text_blocks_from_anthropic_messages",
|
||||
wraps=handler.strip_empty_text_blocks_from_anthropic_messages,
|
||||
) as spy:
|
||||
result = handler.anthropic_messages_handler(
|
||||
max_tokens=10,
|
||||
messages=_empty_block_msgs(),
|
||||
model="anthropic/claude-3-5-sonnet-20241022",
|
||||
custom_llm_provider="anthropic",
|
||||
mock_response="hi there",
|
||||
)
|
||||
assert spy.call_count == 1 # sanitized exactly once here
|
||||
assert result is not None
|
||||
|
||||
|
||||
def test_handler_skips_strip_when_presanitized():
|
||||
"""Async wrapper already sanitized -> handler must NOT rescan."""
|
||||
from litellm.llms.anthropic.experimental_pass_through.messages import handler
|
||||
|
||||
with patch.object(
|
||||
handler,
|
||||
"strip_empty_text_blocks_from_anthropic_messages",
|
||||
wraps=handler.strip_empty_text_blocks_from_anthropic_messages,
|
||||
) as spy:
|
||||
result = handler.anthropic_messages_handler(
|
||||
max_tokens=10,
|
||||
messages=_empty_block_msgs(),
|
||||
model="anthropic/claude-3-5-sonnet-20241022",
|
||||
custom_llm_provider="anthropic",
|
||||
mock_response="hi there",
|
||||
_litellm_messages_presanitized=True,
|
||||
)
|
||||
assert spy.call_count == 0 # skipped the redundant scan
|
||||
assert result is not None
|
||||
|
||||
|
||||
def test_presanitized_flag_not_leaked_to_provider_params():
|
||||
"""The private sentinel must be popped, never forwarded as a request param."""
|
||||
from litellm.llms.anthropic.experimental_pass_through.messages import handler
|
||||
|
||||
captured = {}
|
||||
|
||||
def fake_base_handler(*args, **kwargs):
|
||||
captured.update(kwargs)
|
||||
captured["optional"] = kwargs.get(
|
||||
"anthropic_messages_optional_request_params", {}
|
||||
)
|
||||
return "stub"
|
||||
|
||||
with patch.object(
|
||||
handler.base_llm_http_handler,
|
||||
"anthropic_messages_handler",
|
||||
side_effect=fake_base_handler,
|
||||
):
|
||||
handler.anthropic_messages_handler(
|
||||
max_tokens=10,
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
model="anthropic/claude-3-5-sonnet-20241022",
|
||||
custom_llm_provider="anthropic",
|
||||
_litellm_messages_presanitized=True,
|
||||
)
|
||||
|
||||
assert "_litellm_messages_presanitized" not in captured.get("optional", {})
|
||||
assert "_litellm_messages_presanitized" not in captured.get("kwargs", {})
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_wrapper_sets_presanitized_and_sanitizes_once():
|
||||
"""End-to-end: wrapper sanitizes (once) AND signals the handler to skip."""
|
||||
from litellm.llms.anthropic.experimental_pass_through.messages import handler
|
||||
|
||||
captured = {}
|
||||
|
||||
def fake_handler(*args, **kwargs):
|
||||
captured["messages"] = kwargs.get("messages")
|
||||
captured["presanitized"] = kwargs.get("_litellm_messages_presanitized")
|
||||
return "stub"
|
||||
|
||||
fake_loop = MagicMock()
|
||||
fake_loop.run_in_executor = lambda _e, func: _async_return(func())
|
||||
|
||||
with (
|
||||
patch.object(handler, "anthropic_messages_handler", side_effect=fake_handler),
|
||||
patch("asyncio.get_event_loop", return_value=fake_loop),
|
||||
patch.object(
|
||||
handler,
|
||||
"strip_empty_text_blocks_from_anthropic_messages",
|
||||
wraps=handler.strip_empty_text_blocks_from_anthropic_messages,
|
||||
) as spy,
|
||||
):
|
||||
await handler.anthropic_messages(
|
||||
max_tokens=100,
|
||||
messages=_empty_block_msgs(),
|
||||
model="anthropic/claude-sonnet-4-5-20250929",
|
||||
custom_llm_provider="anthropic",
|
||||
api_key="k",
|
||||
)
|
||||
|
||||
# Wrapper stripped exactly once (the handler is faked, so its skipped
|
||||
# call never runs anyway -- the point is the wrapper still sanitizes).
|
||||
assert spy.call_count == 1
|
||||
assert captured["presanitized"] is True
|
||||
assert [b["type"] for b in captured["messages"][0]["content"]] == ["tool_use"]
|
||||
|
||||
+56
@@ -0,0 +1,56 @@
|
||||
"""
|
||||
Regression tests for the /v1/messages request-parse fast paths:

- get_requested_anthropic_messages_optional_param must still filter to the
  valid AnthropicMessagesRequestOptionalParams keys and drop None values,
  while resolving the (static) type hints only once per process.
"""

from litellm.llms.anthropic.experimental_pass_through.messages.utils import (
    AnthropicMessagesRequestUtils,
    _anthropic_messages_optional_param_keys,
)


def test_optional_param_filtering_unchanged():
    params = {
        "temperature": 0.5,
        "top_p": None,  # None dropped
        "tools": [{"name": "x"}],
        "not_a_real_param": "drop me",  # invalid key dropped
        "stream": True,
    }
    result = (
        AnthropicMessagesRequestUtils.get_requested_anthropic_messages_optional_param(
            params
        )
    )
    assert result == {"temperature": 0.5, "tools": [{"name": "x"}], "stream": True}
    assert "top_p" not in result
    assert "not_a_real_param" not in result


def test_valid_keys_are_memoized():
    _anthropic_messages_optional_param_keys.cache_clear()
    first = _anthropic_messages_optional_param_keys()
    for _ in range(50):
        AnthropicMessagesRequestUtils.get_requested_anthropic_messages_optional_param(
            {"temperature": 0.1}
        )
    info = _anthropic_messages_optional_param_keys.cache_info()
    # Resolved exactly once despite many calls.
    assert info.misses == 1
    assert info.hits >= 50
    # Stable identity (frozenset) returned each call.
    assert _anthropic_messages_optional_param_keys() is first
    assert isinstance(first, frozenset)
    assert "temperature" in first and "tools" in first


def test_empty_params():
    assert (
        AnthropicMessagesRequestUtils.get_requested_anthropic_messages_optional_param(
            {}
        )
        == {}
    )
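The memoization these tests check can be sketched with the standard library alone. The following is an illustrative example, not the LiteLLM source: `ExampleOptionalParams`, `optional_param_keys`, and `filter_optional_params` are hypothetical names standing in for the real TypedDict and helpers.

```python
from functools import lru_cache
from typing import Optional, TypedDict, get_type_hints


class ExampleOptionalParams(TypedDict, total=False):
    # Hypothetical stand-in for the real optional-params TypedDict.
    temperature: Optional[float]
    top_p: Optional[float]
    stream: Optional[bool]


@lru_cache(maxsize=1)
def optional_param_keys() -> frozenset:
    # get_type_hints resolves annotations on every call; caching the
    # resulting key set means that work happens once per process.
    return frozenset(get_type_hints(ExampleOptionalParams))


def filter_optional_params(params: dict) -> dict:
    valid = optional_param_keys()
    return {k: v for k, v in params.items() if k in valid and v is not None}
```

Returning a `frozenset` keeps the cached value immutable, so callers cannot accidentally change the shared key set.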
|
||||
@@ -101,6 +101,73 @@ def test_get_agentic_loop_settings_defaults_and_overrides():
    assert fingerprints == ["fp-1", "fp-2"]


def test_has_agentic_completion_hook_detection(monkeypatch):
    """The streaming path skips the agentic wrapper only when no callback
    overrides async_should_run_agentic_loop. Verify both directions."""
    from litellm.integrations.custom_logger import CustomLogger

    handler = BaseLLMHTTPHandler()
    logging_obj = Mock()
    logging_obj.dynamic_success_callbacks = []

    # No callbacks at all -> no agentic hook.
    monkeypatch.setattr(litellm, "callbacks", [])
    assert handler._has_agentic_completion_hook(logging_obj) is False

    # A plain CustomLogger that does NOT override the gate -> still no hook
    # (so the wrapper is safely skipped).
    class _PlainLogger(CustomLogger):
        pass

    monkeypatch.setattr(litellm, "callbacks", [_PlainLogger()])
    assert handler._has_agentic_completion_hook(logging_obj) is False

    # A logger that overrides the gate (directly) -> hook present.
    class _AgenticLogger(CustomLogger):
        async def async_should_run_agentic_loop(
            self, response, model, messages, tools, stream, custom_llm_provider, kwargs
        ):
            return True, {}

    monkeypatch.setattr(litellm, "callbacks", [_AgenticLogger()])
    assert handler._has_agentic_completion_hook(logging_obj) is True

    # Override inherited through an intermediate class is still detected
    # (function-identity check, not a leaf __dict__ check).
    class _DerivedAgenticLogger(_AgenticLogger):
        pass

    monkeypatch.setattr(litellm, "callbacks", [_DerivedAgenticLogger()])
    assert handler._has_agentic_completion_hook(logging_obj) is True

    # Hook supplied via logging_obj.dynamic_success_callbacks is detected too.
    monkeypatch.setattr(litellm, "callbacks", [])
    logging_obj.dynamic_success_callbacks = [_AgenticLogger()]
    assert handler._has_agentic_completion_hook(logging_obj) is True

    # String-named callback entry (e.g. "datadog") must be resolved to its
    # CustomLogger instance via get_custom_logger_compatible_class -- the same
    # way ProxyLogging._callback_capabilities handles them. Without that
    # resolution a string-registered agentic callback would be silently
    # skipped and the buffering wrapper would never fire.
    logging_obj.dynamic_success_callbacks = []
    agentic_via_string = _AgenticLogger()
    monkeypatch.setattr(litellm, "callbacks", ["fake_string_callback"])
    monkeypatch.setattr(
        "litellm.litellm_core_utils.litellm_logging.get_custom_logger_compatible_class",
        lambda name: agentic_via_string if name == "fake_string_callback" else None,
    )
    assert handler._has_agentic_completion_hook(logging_obj) is True

    # Unresolvable string (returns None) is skipped, no false positive.
    monkeypatch.setattr(litellm, "callbacks", ["unknown_callback"])
    monkeypatch.setattr(
        "litellm.litellm_core_utils.litellm_logging.get_custom_logger_compatible_class",
        lambda name: None,
    )
    assert handler._has_agentic_completion_hook(logging_obj) is False


def test_fingerprint_agentic_tools_is_deterministic():
    handler = BaseLLMHTTPHandler()
    tools_a = {"tool_calls": [{"id": "1", "input": {"q": "abc"}, "name": "web_search"}]}
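The hook-detection test above relies on a "function-identity check, not a leaf `__dict__` check". One way such a check could look is sketched here. `BaseLogger`, `overrides_agentic_gate`, and `has_agentic_hook` are hypothetical names for illustration; the real `_has_agentic_completion_hook` may differ, for example in how it resolves string-named callbacks.

```python
from typing import Any, Iterable


class BaseLogger:
    """Hypothetical stand-in for CustomLogger."""

    async def async_should_run_agentic_loop(self, *args: Any, **kwargs: Any):
        return False, {}


def overrides_agentic_gate(callback: Any) -> bool:
    if not isinstance(callback, BaseLogger):
        # Strings and other non-logger entries are ignored in this sketch.
        return False
    # Attribute lookup on the class follows the MRO, so an override defined
    # on any intermediate parent is found, not only one on the leaf class.
    method = getattr(type(callback), "async_should_run_agentic_loop", None)
    return method is not BaseLogger.async_should_run_agentic_loop


def has_agentic_hook(callbacks: Iterable[Any]) -> bool:
    return any(overrides_agentic_gate(cb) for cb in callbacks)
```

Comparing function objects with `is` avoids instantiating or calling anything, so the check stays cheap enough to run per request.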
@@ -422,3 +489,143 @@ def test_sync_delete_responses_omits_body_for_azure():
    assert captured["url"].endswith(
        "/openai/responses/resp_xyz?api-version=2025-03-01-preview"
    )


# ---------------------------------------------------------------------------
# Parity tests: request-body is serialized once and reused for the wire.
# (_async_post_anthropic_messages_with_http_error_retry)
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_anthropic_post_uses_prebuilt_body_without_redumping():
    """When the caller passes a pre-serialized (unsigned) body, attempt 0 must
    send exactly those bytes -- no second json.dumps of request_body."""
    import json as _json

    handler = BaseLLMHTTPHandler()
    request_body = {"model": "claude", "messages": [{"role": "user", "content": "hi"}]}
    prebuilt = _json.dumps(request_body)

    ok_resp = Mock()
    ok_resp.raise_for_status = Mock(return_value=None)
    http_client = Mock()
    http_client.post = AsyncMock(return_value=ok_resp)

    provider_config = Mock()
    provider_config.max_retry_on_anthropic_messages_http_error = 2

    logging_obj = Mock()
    logging_obj.model_call_details = {}

    out = await handler._async_post_anthropic_messages_with_http_error_retry(
        async_httpx_client=http_client,
        request_url="http://x/v1/messages",
        headers={},
        signed_json_body=prebuilt,
        request_body=request_body,
        stream=False,
        logging_obj=logging_obj,
        provider_config=provider_config,
        litellm_params=GenericLiteLLMParams(),
        api_key="k",
        model="claude",
    )
    assert out is ok_resp
    http_client.post.assert_awaited_once()
    sent = http_client.post.await_args.kwargs["data"]
    # Byte-identical to the legacy wire serialization, and the SAME object the
    # caller already used for the pre-call log (no re-serialization).
    assert sent == prebuilt
    assert sent is prebuilt


@pytest.mark.asyncio
async def test_anthropic_post_falls_back_to_json_dumps_when_unsigned_none():
    """signed_json_body=None keeps the exact legacy behavior."""
    import json as _json

    handler = BaseLLMHTTPHandler()
    request_body = {"model": "claude", "messages": [{"role": "user", "content": "yo"}]}

    ok_resp = Mock()
    ok_resp.raise_for_status = Mock(return_value=None)
    http_client = Mock()
    http_client.post = AsyncMock(return_value=ok_resp)

    provider_config = Mock()
    provider_config.max_retry_on_anthropic_messages_http_error = 1
    logging_obj = Mock()
    logging_obj.model_call_details = {}

    await handler._async_post_anthropic_messages_with_http_error_retry(
        async_httpx_client=http_client,
        request_url="http://x/v1/messages",
        headers={},
        signed_json_body=None,
        request_body=request_body,
        stream=False,
        logging_obj=logging_obj,
        provider_config=provider_config,
        litellm_params=GenericLiteLLMParams(),
        api_key="k",
        model="claude",
    )
    sent = http_client.post.await_args.kwargs["data"]
    assert sent == _json.dumps(request_body)


@pytest.mark.asyncio
async def test_anthropic_post_retry_reserializes_mutated_body():
    """On a retryable HTTP error the body is mutated + re-signed; the prebuilt
    body must NOT be reused -- attempt 1 sends the freshly serialized body."""
    import json as _json

    handler = BaseLLMHTTPHandler()
    request_body = {"model": "claude", "messages": [{"role": "user", "content": "a"}]}
    prebuilt = _json.dumps(request_body)

    err_resp = Mock()
    http_error = httpx.HTTPStatusError(
        "bad", request=Mock(), response=Mock(status_code=400)
    )
    err_resp.raise_for_status = Mock(side_effect=http_error)
    ok_resp = Mock()
    ok_resp.raise_for_status = Mock(return_value=None)
    http_client = Mock()
    http_client.post = AsyncMock(side_effect=[err_resp, ok_resp])

    def _mutate(e, request_data):
        request_data["messages"][0]["content"] = "MUTATED"

    provider_config = Mock()
    provider_config.max_retry_on_anthropic_messages_http_error = 2
    provider_config.should_retry_anthropic_messages_on_http_error = Mock(
        return_value=True
    )
    provider_config.transform_anthropic_messages_request_on_http_error = _mutate
    # Re-sign returns no signed body (native anthropic path) -> must re-dump.
    provider_config.sign_request = Mock(return_value=({}, None))

    logging_obj = Mock()
    logging_obj.model_call_details = {}

    await handler._async_post_anthropic_messages_with_http_error_retry(
        async_httpx_client=http_client,
        request_url="http://x/v1/messages",
        headers={},
        signed_json_body=prebuilt,
        request_body=request_body,
        stream=False,
        logging_obj=logging_obj,
        provider_config=provider_config,
        litellm_params=GenericLiteLLMParams(),
        api_key="k",
        model="claude",
    )
    assert http_client.post.await_count == 2
    first_sent = http_client.post.await_args_list[0].kwargs["data"]
    second_sent = http_client.post.await_args_list[1].kwargs["data"]
    assert first_sent == prebuilt  # attempt 0 used prebuilt
    assert second_sent == _json.dumps(request_body)  # attempt 1 re-serialized
    assert "MUTATED" in second_sent  # ... the mutated body
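The three parity tests above pin down one rule: reuse the pre-serialized body on the first attempt, and re-serialize whenever a retry may have mutated the request. A simplified synchronous sketch of that rule follows. `post_with_retry`, `send`, and `mutate_on_error` are hypothetical names, and the real handler additionally re-signs the request and decides retryability through the provider config.

```python
import json
from typing import Any, Callable, Dict, List, Optional


def post_with_retry(
    send: Callable[[str], bool],
    request_body: Dict[str, Any],
    prebuilt: Optional[str] = None,
    mutate_on_error: Optional[Callable[[Dict[str, Any]], None]] = None,
    max_attempts: int = 2,
) -> List[str]:
    """Return every payload sent; `send` returns True on success."""
    # Attempt 0: reuse the caller's serialization when one was provided.
    payload = prebuilt if prebuilt is not None else json.dumps(request_body)
    sent: List[str] = []
    for _ in range(max_attempts):
        sent.append(payload)
        if send(payload):
            break
        if mutate_on_error is not None:
            mutate_on_error(request_body)
        # The body may have changed, so the stale prebuilt string must not
        # be reused for the next attempt.
        payload = json.dumps(request_body)
    return sent
```

The saving applies only to the common success-on-first-attempt case; retries pay the serialization cost again, which keeps the wire payload correct after mutation.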
@@ -684,3 +684,362 @@ class TestAnthropicBatchPassthroughCostTracking:
        mock_proxy_logging_obj.get_proxy_hook.assert_called_once_with(
            "managed_files"
        )


class TestPureTextFastPathParity:
    """
    The pure-text fast path in _build_complete_streaming_response must produce
    a response (and downstream logging/cost payload) byte-identical to the
    legacy stream_chunk_builder path. Anything non-text must fall back.
    """

    @staticmethod
    def _sse(event, data):
        return f"event: {event}\ndata: {json.dumps(data)}\n\n".encode()

    @staticmethod
    def _to_all_chunks(raw_frames):
        # Mirror production: raw bytes -> _convert_raw_bytes_to_str_lines.
        from litellm.proxy.pass_through_endpoints.streaming_handler import (
            PassThroughStreamingHandler,
        )

        return PassThroughStreamingHandler._convert_raw_bytes_to_str_lines(raw_frames)

    @staticmethod
    def _norm(resp):
        if resp is None:
            return None
        d = resp.model_dump()
        # id / created are non-deterministic even between two legacy runs.
        d.pop("id", None)
        d.pop("created", None)
        return d

    def _text_stream(
        self,
        texts,
        *,
        input_tokens=12,
        cache_creation=0,
        cache_read=0,
        stop_reason="end_turn",
        with_ping=True,
        blocks=1,
    ):
        frames = [
            self._sse(
                "message_start",
                {
                    "type": "message_start",
                    "message": {
                        "id": "msg_abc",
                        "type": "message",
                        "role": "assistant",
                        "model": "claude-3-5-sonnet-20241022",
                        "content": [],
                        "stop_reason": None,
                        "stop_sequence": None,
                        "usage": {
                            "input_tokens": input_tokens,
                            "output_tokens": 0,
                            "cache_creation_input_tokens": cache_creation,
                            "cache_read_input_tokens": cache_read,
                        },
                    },
                },
            )
        ]
        per_block = max(1, len(texts) // blocks)
        ti = 0
        for b in range(blocks):
            frames.append(
                self._sse(
                    "content_block_start",
                    {
                        "type": "content_block_start",
                        "index": b,
                        "content_block": {"type": "text", "text": ""},
                    },
                )
            )
            if with_ping:
                frames.append(self._sse("ping", {"type": "ping"}))
            chunk_texts = texts[ti : ti + per_block] if b < blocks - 1 else texts[ti:]
            ti += per_block
            for t in chunk_texts:
                frames.append(
                    self._sse(
                        "content_block_delta",
                        {
                            "type": "content_block_delta",
                            "index": b,
                            "delta": {"type": "text_delta", "text": t},
                        },
                    )
                )
            frames.append(
                self._sse(
                    "content_block_stop", {"type": "content_block_stop", "index": b}
                )
            )
        frames.append(
            self._sse(
                "message_delta",
                {
                    "type": "message_delta",
                    "delta": {"stop_reason": stop_reason, "stop_sequence": None},
                    "usage": {"output_tokens": len(texts)},
                },
            )
        )
        frames.append(self._sse("message_stop", {"type": "message_stop"}))
        return frames

    def _assert_parity(self, raw_frames):
        all_chunks = self._to_all_chunks(raw_frames)
        lo1 = MagicMock()
        lo1.model_call_details = {}
        lo2 = MagicMock()
        lo2.model_call_details = {}

        legacy = AnthropicPassthroughLoggingHandler._build_complete_streaming_response_legacy(
            all_chunks=list(all_chunks),
            litellm_logging_obj=lo1,
            model="claude-3-5-sonnet-20241022",
        )
        fast = AnthropicPassthroughLoggingHandler._build_complete_streaming_response(
            all_chunks=list(all_chunks),
            litellm_logging_obj=lo2,
            model="claude-3-5-sonnet-20241022",
        )
        assert self._norm(fast) == self._norm(legacy)

        # Downstream logged/billed payload must also match.
        start = datetime.now()
        end = datetime.now()
        k_legacy = AnthropicPassthroughLoggingHandler._create_anthropic_response_logging_payload(
            litellm_model_response=legacy,
            model="claude-3-5-sonnet-20241022",
            kwargs={},
            start_time=start,
            end_time=end,
            logging_obj=lo1,
        )
        k_fast = AnthropicPassthroughLoggingHandler._create_anthropic_response_logging_payload(
            litellm_model_response=fast,
            model="claude-3-5-sonnet-20241022",
            kwargs={},
            start_time=start,
            end_time=end,
            logging_obj=lo2,
        )
        # Usage drives cost; it must be byte-identical between paths.
        assert getattr(fast, "usage", None) == getattr(legacy, "usage", None)

        # And the full logged payload (sans non-deterministic response id).
        def _scrub(p):
            d = dict(p)
            r = d.get("complete_streaming_response_in_db") or d.get(
                "complete_streaming_response"
            )
            return d, getattr(r, "usage", None)

        assert _scrub(k_fast)[1] == _scrub(k_legacy)[1]

    def test_parity_simple_text(self):
        self._assert_parity(self._text_stream(["Hello", " ", "world", "!"]))

    def test_parity_single_delta(self):
        self._assert_parity(self._text_stream(["Just one piece of text."]))

    def test_parity_cache_tokens(self):
        self._assert_parity(
            self._text_stream(
                ["a", "b", "c"], input_tokens=20, cache_creation=5, cache_read=7
            )
        )

    def test_parity_max_tokens_stop(self):
        self._assert_parity(self._text_stream(["tok"] * 8, stop_reason="max_tokens"))

    def test_parity_no_ping(self):
        self._assert_parity(self._text_stream(["x", "y"], with_ping=False))

    def test_parity_empty_text_deltas(self):
        self._assert_parity(self._text_stream(["", "hi", "", "there"]))

    def test_parity_multi_text_block(self):
        self._assert_parity(self._text_stream(["p1", "p2", "p3", "p4"], blocks=2))

    def test_parity_multibyte_batched_frames(self):
        # Several SSE events delivered in one network chunk.
        frames = self._text_stream(["alpha", "beta", "gamma"])
        merged = b"".join(frames)
        self._assert_parity([merged])

    def test_collapse_returns_none_for_tool_use(self):
        frames = [
            self._sse(
                "message_start",
                {
                    "type": "message_start",
                    "message": {
                        "id": "m",
                        "model": "x",
                        "role": "assistant",
                        "type": "message",
                        "content": [],
                        "usage": {"input_tokens": 1, "output_tokens": 0},
                    },
                },
            ),
            self._sse(
                "content_block_start",
                {
                    "type": "content_block_start",
                    "index": 0,
                    "content_block": {
                        "type": "tool_use",
                        "id": "t1",
                        "name": "get_weather",
                        "input": {},
                    },
                },
            ),
            self._sse(
                "content_block_delta",
                {
                    "type": "content_block_delta",
                    "index": 0,
                    "delta": {"type": "input_json_delta", "partial_json": "{}"},
                },
            ),
            self._sse("content_block_stop", {"type": "content_block_stop", "index": 0}),
            self._sse("message_stop", {"type": "message_stop"}),
        ]
        all_chunks = self._to_all_chunks(frames)
        assert (
            AnthropicPassthroughLoggingHandler._collapse_pure_text_chunks(
                list(all_chunks)
            )
            is None
        )

    def test_collapse_returns_none_for_thinking(self):
        frames = [
            self._sse(
                "message_start",
                {
                    "type": "message_start",
                    "message": {
                        "id": "m",
                        "model": "x",
                        "role": "assistant",
                        "type": "message",
                        "content": [],
                        "usage": {"input_tokens": 1, "output_tokens": 0},
                    },
                },
            ),
            self._sse(
                "content_block_start",
                {
                    "type": "content_block_start",
                    "index": 0,
                    "content_block": {"type": "thinking", "thinking": ""},
                },
            ),
            self._sse(
                "content_block_delta",
                {
                    "type": "content_block_delta",
                    "index": 0,
                    "delta": {"type": "thinking_delta", "thinking": "hmm"},
                },
            ),
            self._sse("message_stop", {"type": "message_stop"}),
        ]
        all_chunks = self._to_all_chunks(frames)
        assert (
            AnthropicPassthroughLoggingHandler._collapse_pure_text_chunks(
                list(all_chunks)
            )
            is None
        )

    def test_collapse_actually_shrinks_chunk_count(self):
        frames = self._text_stream(["a"] * 50)
        all_chunks = list(self._to_all_chunks(frames))
        collapsed = AnthropicPassthroughLoggingHandler._collapse_pure_text_chunks(
            all_chunks
        )
        assert collapsed is not None
        # 50 text deltas + 50 event markers + 1 ping collapse to far fewer.
        assert len(collapsed) < len(all_chunks) / 2

    def test_collapse_returns_none_for_interleaved_block_indexes(self):
        """
        Anthropic sends content blocks strictly sequentially (start/deltas/stop
        for one, then the next). If a stream ever interleaves deltas across
        block indexes, the fast path must bail to legacy rather than merge text
        from different blocks under a single index.
        """
        frames = [
            self._sse(
                "message_start",
                {
                    "type": "message_start",
                    "message": {
                        "id": "msg_abc",
                        "type": "message",
                        "role": "assistant",
                        "model": "claude-3-5-sonnet-20241022",
                        "content": [],
                        "stop_reason": None,
                        "stop_sequence": None,
                        "usage": {"input_tokens": 1, "output_tokens": 0},
                    },
                },
            ),
            self._sse(
                "content_block_start",
                {
                    "type": "content_block_start",
                    "index": 0,
                    "content_block": {"type": "text", "text": ""},
                },
            ),
            self._sse(
                "content_block_start",
                {
                    "type": "content_block_start",
                    "index": 1,
                    "content_block": {"type": "text", "text": ""},
                },
            ),
            # Interleave: delta for block 0, then delta for block 1, with no
            # content_block_stop between them.
            self._sse(
                "content_block_delta",
                {
                    "type": "content_block_delta",
                    "index": 0,
                    "delta": {"type": "text_delta", "text": "hello "},
                },
            ),
            self._sse(
                "content_block_delta",
                {
                    "type": "content_block_delta",
                    "index": 1,
                    "delta": {"type": "text_delta", "text": "world"},
                },
            ),
            self._sse("message_stop", {"type": "message_stop"}),
        ]
        all_chunks = list(self._to_all_chunks(frames))
        assert (
            AnthropicPassthroughLoggingHandler._collapse_pure_text_chunks(all_chunks)
            is None
        )
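The collapse behavior these tests describe (merge text deltas per block, bail out on anything non-text or interleaved) can be sketched on already-parsed events. This is an illustration under assumptions, not the code of `_collapse_pure_text_chunks`: the real method operates on raw SSE string lines, and `collapse_text_events` plus its dict-based event shape are hypothetical.

```python
from typing import Any, Dict, List, Optional

Event = Dict[str, Any]


def collapse_text_events(events: List[Event]) -> Optional[List[Event]]:
    """Merge text deltas per block; None means 'fall back to the legacy path'."""
    out: List[Event] = []
    open_index: Optional[int] = None  # index of the block currently open
    pending: List[str] = []  # text pieces collected for the open block

    for ev in events:
        kind = ev.get("type")
        if kind == "content_block_start":
            # A second start before a stop means interleaving; a non-text
            # block (tool_use, thinking, ...) is out of scope for the fast path.
            if open_index is not None or ev["content_block"].get("type") != "text":
                return None
            open_index = ev["index"]
            out.append(ev)
        elif kind == "content_block_delta":
            if ev["index"] != open_index or ev["delta"].get("type") != "text_delta":
                return None
            pending.append(ev["delta"]["text"])
        elif kind == "content_block_stop":
            if ev["index"] != open_index:
                return None
            if pending:
                # One delta carrying the concatenated text replaces N deltas.
                out.append(
                    {
                        "type": "content_block_delta",
                        "index": open_index,
                        "delta": {"type": "text_delta", "text": "".join(pending)},
                    }
                )
            out.append(ev)
            open_index, pending = None, []
        else:
            out.append(ev)  # message_start, ping, message_delta, message_stop
    # A block that never closed is treated as unsafe to collapse.
    return None if open_index is not None else out
```

Returning `None` instead of raising lets the caller treat "cannot collapse" as a routine signal to use the unchanged legacy builder.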
@@ -10,6 +10,7 @@ from fastapi.responses import JSONResponse, StreamingResponse

import litellm
from litellm._uuid import uuid
from litellm.integrations.custom_logger import CustomLogger
from litellm.integrations.opentelemetry import UserAPIKeyAuth
from litellm.proxy.common_request_processing import (
    ProxyBaseLLMRequestProcessing,
@@ -1316,8 +1317,17 @@ class TestCommonRequestProcessingHelpers:
             yield 'data: {"content": "chunk 3"}\n\n'
             yield "data: [DONE]\n\n"
 
-        # Patch the tracer in the common_request_processing module
-        with patch("litellm.proxy.common_request_processing.tracer", mock_tracer):
+        # Patch the tracer in the common_request_processing module. The
+        # per-chunk span is gated on _DD_STREAMING_TRACE_ENABLED (resolved at
+        # import from the real tracer, a NullTracer by default), so enable it
+        # explicitly to exercise the tracing path.
+        with (
+            patch("litellm.proxy.common_request_processing.tracer", mock_tracer),
+            patch(
+                "litellm.proxy.common_request_processing._DD_STREAMING_TRACE_ENABLED",
+                True,
+            ),
+        ):
             response = await create_response(mock_generator(), "text/event-stream", {})
 
         assert response.status_code == 200
@@ -1345,6 +1355,40 @@ class TestCommonRequestProcessingHelpers:
                args[0] == "streaming.chunk.yield"
            ), f"Call {i} should have operation name 'streaming.chunk.yield', got {args[0]}"

    async def test_create_streaming_response_skips_dd_trace_when_disabled(self):
        """When DD tracing is disabled (the default), the per-chunk span
        context manager is skipped entirely but all chunks still stream."""
        from unittest.mock import patch

        mock_tracer = MagicMock()

        async def mock_generator():
            yield 'data: {"content": "chunk 1"}\n\n'
            yield 'data: {"content": "chunk 2"}\n\n'
            yield "data: [DONE]\n\n"

        with (
            patch("litellm.proxy.common_request_processing.tracer", mock_tracer),
            patch(
                "litellm.proxy.common_request_processing._DD_STREAMING_TRACE_ENABLED",
                False,
            ),
        ):
            response = await create_response(mock_generator(), "text/event-stream", {})

        assert response.status_code == 200

        content = await self.consume_stream(response)

        # All chunks stream through unchanged ...
        assert content == [
            'data: {"content": "chunk 1"}\n\n',
            'data: {"content": "chunk 2"}\n\n',
            "data: [DONE]\n\n",
        ]
        # ... but no per-chunk span was created.
        assert mock_tracer.trace.call_count == 0

    async def test_create_streaming_response_dd_trace_with_error_chunk(self):
        """
        Test that when the first chunk contains an error, JSONResponse is returned
@@ -2199,3 +2243,77 @@ class TestHandleLLMApiExceptionDictDetail:
        proxy_exc = await self._invoke(exc)
        assert proxy_exc.message == "Content blocked by guardrail"
        assert proxy_exc.provider_specific_fields is None


class TestAsyncStreamingDataGeneratorFastPath:
    """Fast/slow path branching in async_streaming_data_generator."""

    @staticmethod
    async def _aiter(items):
        for item in items:
            yield item

    @pytest.mark.asyncio
    async def test_fast_path_skips_per_chunk_hook(self, monkeypatch):
        """With no callbacks/guardrails/cost-injection, chunks pass through
        unchanged and the per-chunk hook is NOT awaited."""
        monkeypatch.setattr(litellm, "callbacks", [])
        ProxyLogging._callback_capabilities_cache.clear()

        proxy_logging_obj = ProxyLogging(user_api_key_cache=MagicMock())
        hook_spy = AsyncMock(side_effect=lambda **kw: kw["response"])
        monkeypatch.setattr(
            proxy_logging_obj, "async_post_call_streaming_hook", hook_spy
        )

        chunks = [b"event: a\ndata: {}\n\n", b"event: b\ndata: {}\n\n"]
        out = [
            c
            async for c in ProxyBaseLLMRequestProcessing.async_streaming_data_generator(
                response=self._aiter(chunks),
                user_api_key_dict=MagicMock(spec=UserAPIKeyAuth),
                request_data={"model": "claude-x"},
                proxy_logging_obj=proxy_logging_obj,
                serialize_chunk=ProxyBaseLLMRequestProcessing.return_sse_chunk,
                serialize_error=lambda e: "data: error\n\n",
            )
        ]

        assert out == chunks  # bytes pass through return_sse_chunk untouched
        hook_spy.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_slow_path_runs_per_chunk_hook(self, monkeypatch):
        """A callback that overrides async_post_call_streaming_hook forces the
        slow path and the per-chunk hook is invoked."""

        class _StreamingCb(CustomLogger):
            async def async_post_call_streaming_hook(self, user_api_key_dict, response):
                return response

        cb = _StreamingCb()
        monkeypatch.setattr(litellm, "callbacks", [cb])
        ProxyLogging._callback_capabilities_cache.clear()

        proxy_logging_obj = ProxyLogging(user_api_key_cache=MagicMock())
        hook_spy = AsyncMock(side_effect=lambda **kw: kw["response"])
        monkeypatch.setattr(
            proxy_logging_obj, "async_post_call_streaming_hook", hook_spy
        )

        out = [
            c
            async for c in ProxyBaseLLMRequestProcessing.async_streaming_data_generator(
                response=self._aiter([{"type": "message_stop"}]),
                user_api_key_dict=MagicMock(spec=UserAPIKeyAuth),
                request_data={"model": "claude-x"},
                proxy_logging_obj=proxy_logging_obj,
                serialize_chunk=ProxyBaseLLMRequestProcessing.return_sse_chunk,
                serialize_error=lambda e: "data: error\n\n",
            )
        ]

        assert len(out) == 1
        hook_spy.assert_awaited_once()

        ProxyLogging._callback_capabilities_cache.clear()
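The fast/slow branching tested above amounts to deciding once, before the loop, whether any per-chunk hook is needed. A minimal sketch of that shape, assuming a hypothetical `stream_chunks` generator and an optional `hook` coroutine (the real `async_streaming_data_generator` also handles serialization, cost injection, and error chunks):

```python
import asyncio
from typing import Any, AsyncIterator, Awaitable, Callable, Optional

Hook = Callable[[Any], Awaitable[Any]]


async def stream_chunks(
    source: AsyncIterator[Any], hook: Optional[Hook] = None
) -> AsyncIterator[Any]:
    if hook is None:
        # Fast path: the decision is made once, so no coroutine is created
        # or awaited per chunk.
        async for chunk in source:
            yield chunk
        return
    # Slow path: every chunk goes through the hook before being yielded.
    async for chunk in source:
        yield await hook(chunk)
```

Hoisting the check out of the loop is what removes the per-chunk overhead; both branches yield the same chunks when the hook is an identity function.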
@@ -111,6 +111,28 @@ def test_callback_capabilities_captures_iterator_override(monkeypatch):
    assert kind == "override"


def test_callback_capabilities_detects_inherited_streaming_chunk_override(monkeypatch):
    """
    ``async_post_call_streaming_hook`` must be detected even when the override
    lives on an intermediate parent class -- a vendor base class can carry the
    override and the registered class can add nothing else. Before this PR the
    hook was unconditionally invoked, so a leaf-class ``__dict__`` miss here
    would silently drop the inherited hook.
    """
    ProxyLogging._callback_capabilities_cache.clear()

    class _StreamingBase(CustomLogger):
        async def async_post_call_streaming_hook(self, *args, **kwargs):  # type: ignore[override]
            return kwargs.get("response")

    class _LeafWithoutOverride(_StreamingBase):
        pass

    monkeypatch.setattr(litellm, "callbacks", [_LeafWithoutOverride()])
    caps = ProxyLogging._callback_capabilities()
    assert caps.has_streaming_chunk_override is True


def test_callback_capabilities_cache_invalidates_on_list_change(monkeypatch):
    """The cache key includes (length, id-of-each-callback). Mutating the
    callback list must produce a fresh capability snapshot."""