Files
litellm/tests/test_litellm/proxy/test_common_request_processing.py
T
Yassin Kortam 2eab9ee2c0 perf: reduce per-request and per-chunk overhead across Anthropic streaming hot paths (#28289)
* perf: reduce per-request and per-chunk overhead across Anthropic streaming hot paths

- Introduce pure-text fast-path in `_build_complete_streaming_response` that collapses O(N) `content_block_delta` events into a single equivalent SSE event before conversion, eliminating per-output-token Pydantic `ModelResponseStream` construction; non-text streams (tool_use, thinking, citations) fall back to the unchanged legacy path
- Skip agentic streaming wrapper entirely when no callback overrides `async_should_run_agentic_loop`; the wrapper buffered every chunk and rebuilt the SSE response only to call hooks that all return `(False, {})` — a pure no-op for the default config
- Serialize request body once (`json.dumps`) for both the pre-call log input and the wire, instead of twice; avoids a full O(payload) scan per request, significant for long-context Claude Code histories
- Add fast path in `async_streaming_data_generator` that bypasses the per-chunk `async_post_call_streaming_hook` coroutine await, response-string materialization, and cost-injection call when no callback/guardrail/cost-injection is active (the default config)
- Resolve `_DD_STREAMING_TRACE_ENABLED` once at import time; eliminate per-chunk `NullSpan` context manager allocation when Datadog tracing is disabled (the default)
- Memoize `get_type_hints(AnthropicMessagesRequestOptionalParams)` with `@lru_cache(maxsize=1)` — resolves once per process instead of once per `/v1/messages` request (~80µs each)
- Hoist `cost_injection_active` out of the per-chunk loop in `chunk_processor`; eliminates repeated `getattr` + endpoint-type checks on every streamed byte chunk
- Extract `_build_passthrough_logging_result` from `_route_streaming_logging_to_handler` as a standalone static method to facilitate future off-loop dispatch
- Convert `async_sse_data_generator` from an `async for: yield` trampoline to a direct return of the underlying generator, removing one async-generator layer per streamed chunk
- Skip the redundant `strip_empty_text_blocks_from_anthropic_messages` scan in `anthropic_messages_handler` when the async wrapper has already sanitized the messages (signalled via the `_litellm_messages_presanitized` sentinel, which is popped before it reaches the provider params)
- Gate debug log `f-string` evaluation behind `isEnabledFor(DEBUG)` in both the streaming generator and the transformation layer to avoid serializing entire message payloads on every request at non-debug log levels
- Add benchmark script (`scripts/benchmark_anthropic_messages_perf.py`) with a local mock Anthropic SSE provider for reproducible TTFT and TPM measurement across commits/branches
- Add parity tests asserting fast-path and legacy-path produce byte-identical logged/billed payloads, plus unit tests for agentic hook detection, pre-serialized body reuse, and memoized key resolution

* perf: address greptile review for anthropic streaming hot path

- Bail to legacy in `_collapse_pure_text_chunks` when content_block_delta
  events from different block indexes are observed without an intervening
  flush. Anthropic sends blocks strictly sequentially, but the defensive
  bail prevents silent text merging if the protocol ever interleaves them.
- Replace leaf-class `__dict__` check for `async_post_call_streaming_hook`
  in `_callback_capabilities` with a function-identity comparison that
  walks the MRO. A vendor base class can carry the override and the
  registered class can add nothing else; before this PR the hook was
  unconditionally invoked, so an inherited-override miss would silently
  drop the hook on the streaming path.
- Add unit tests for both behaviors.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* fix(mypy): narrow model_name to str in cost-injection branch

The hoisted cost_injection_active flag in chunk_processor encodes the
`bool(model_name)` requirement but mypy can't track that invariant
through the local, so the per-chunk `_process_chunk_with_cost_injection(
chunk, model_name)` calls flagged Optional[str] vs str. Pin a typed
non-None local inside the cost-injection branch so mypy narrows
correctly without changing runtime behavior.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

---------

Co-authored-by: Yassin Kortam <yassinkortam@g.ucla.edu>
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-23 12:15:59 -07:00

2320 lines
88 KiB
Python

import copy
import datetime
from typing import AsyncGenerator
from unittest.mock import AsyncMock, MagicMock, patch
import httpx
import pytest
from fastapi import HTTPException, Request, Response, status
from fastapi.responses import JSONResponse, StreamingResponse
import litellm
from litellm._uuid import uuid
from litellm.integrations.custom_logger import CustomLogger
from litellm.integrations.opentelemetry import UserAPIKeyAuth
from litellm.proxy.common_request_processing import (
ProxyBaseLLMRequestProcessing,
ProxyConfig,
_extract_error_from_sse_chunk,
_get_cost_breakdown_from_logging_obj,
_has_attribute_error_in_chain,
_is_azure_model_router_request,
_override_openai_response_model,
_parse_event_data_for_error,
create_response,
)
from litellm.proxy.dd_span_tagger import DDSpanTagger
from litellm.proxy.utils import ProxyLogging
class TestProxyBaseLLMRequestProcessing:
    """Unit tests for ``ProxyBaseLLMRequestProcessing`` and related helpers.

    Covers pass-through header preservation, pre-call logic (call-id
    injection, hierarchical router settings, queue-time metadata), custom
    response-header construction (response cost, discount, margin, key
    spend) and the success-header builder used by native Google GenAI routes.
    """

    # ------------------------------------------------------------------
    # Private helpers shared by several tests (not collected by pytest).
    # ------------------------------------------------------------------

    @staticmethod
    def _make_pre_call_proxy_logging_obj() -> MagicMock:
        """Return a ``ProxyLogging`` mock whose ``pre_call_hook`` echoes a deep copy of ``data``."""

        async def mock_pre_call_hook(user_api_key_dict, data, call_type):
            # Hand back a fresh copy rather than the caller's own dict.
            return copy.deepcopy(data)

        proxy_logging_obj = MagicMock(spec=ProxyLogging)
        proxy_logging_obj.pre_call_hook = AsyncMock(side_effect=mock_pre_call_hook)
        return proxy_logging_obj

    @staticmethod
    def _make_unlimited_user() -> MagicMock:
        """Return a plain ``MagicMock`` user with no limits, zero spend and no model region."""
        user = MagicMock()
        user.tpm_limit = None
        user.rpm_limit = None
        user.max_budget = None
        user.spend = 0.0
        user.allowed_model_region = None
        return user

    @staticmethod
    def _make_spend_user(spend=0) -> MagicMock:
        """Return a ``UserAPIKeyAuth``-spec mock with no rate/budget limits and the given ``spend``."""
        user = MagicMock(spec=UserAPIKeyAuth)
        user.tpm_limit = None
        user.rpm_limit = None
        user.max_budget = None
        user.spend = spend
        return user

    @staticmethod
    def _make_logging_obj(model, messages, litellm_call_id, function_id):
        """Build a real non-streaming ``completion`` logging object for cost-breakdown tests."""
        from litellm.litellm_core_utils.litellm_logging import (
            Logging as LiteLLMLoggingObj,
        )

        return LiteLLMLoggingObj(
            model=model,
            messages=messages,
            stream=False,
            call_type="completion",
            start_time=None,
            litellm_call_id=litellm_call_id,
            function_id=function_id,
        )

    # ------------------------------------------------------------------
    # Pass-through processing
    # ------------------------------------------------------------------

    @pytest.mark.asyncio
    async def test_base_passthrough_process_llm_request_preserves_litellm_headers_for_non_streaming_response(
        self, monkeypatch
    ):
        """Upstream headers and x-litellm-* headers set on fastapi_response both survive."""
        processing_obj = ProxyBaseLLMRequestProcessing(data={})

        async def fake_base_process_llm_request(**kwargs):
            passthrough_response = kwargs["fastapi_response"]
            passthrough_response.headers["x-litellm-call-id"] = "test-call-id"
            passthrough_response.headers["x-litellm-version"] = "test-version"
            return httpx.Response(
                status_code=200,
                content=b'{"ok":true}',
                headers={
                    "content-type": "application/json",
                    "x-amzn-requestid": "bedrock-request-id",
                },
            )

        monkeypatch.setattr(
            processing_obj,
            "base_process_llm_request",
            fake_base_process_llm_request,
        )
        result = await processing_obj.base_passthrough_process_llm_request(
            request=MagicMock(spec=Request),
            fastapi_response=Response(),
            user_api_key_dict=MagicMock(spec=UserAPIKeyAuth),
            proxy_logging_obj=MagicMock(spec=ProxyLogging),
            general_settings={},
            proxy_config=MagicMock(spec=ProxyConfig),
            select_data_generator=MagicMock(),
            model="bedrock-test-model",
        )
        assert result.status_code == 200
        assert result.body == b'{"ok":true}'
        assert result.headers["x-amzn-requestid"] == "bedrock-request-id"
        assert result.headers["x-litellm-call-id"] == "test-call-id"
        assert result.headers["x-litellm-version"] == "test-version"

    # ------------------------------------------------------------------
    # Pre-call logic
    # ------------------------------------------------------------------

    @pytest.mark.asyncio
    async def test_common_processing_pre_call_logic_pre_call_hook_receives_litellm_call_id(
        self, monkeypatch
    ):
        """The pre-call hook receives a valid UUID ``litellm_call_id`` that is also returned."""
        processing_obj = ProxyBaseLLMRequestProcessing(data={})
        mock_request = MagicMock(spec=Request)
        mock_request.headers = {}

        async def mock_add_litellm_data_to_request(*args, **kwargs):
            return {}

        mock_proxy_logging_obj = self._make_pre_call_proxy_logging_obj()
        monkeypatch.setattr(
            litellm.proxy.common_request_processing,
            "add_litellm_data_to_request",
            mock_add_litellm_data_to_request,
        )
        # Call the actual method.
        (
            returned_data,
            logging_obj,
        ) = await processing_obj.common_processing_pre_call_logic(
            request=mock_request,
            general_settings={},
            user_api_key_dict=MagicMock(spec=UserAPIKeyAuth),
            proxy_logging_obj=mock_proxy_logging_obj,
            proxy_config=MagicMock(spec=ProxyConfig),
            route_type="acompletion",
        )
        mock_proxy_logging_obj.pre_call_hook.assert_called_once()
        _, call_kwargs = mock_proxy_logging_obj.pre_call_hook.call_args
        data_passed = call_kwargs.get("data", {})
        assert "litellm_call_id" in data_passed
        try:
            uuid.UUID(data_passed["litellm_call_id"])
        except ValueError:
            pytest.fail("litellm_call_id is not a valid UUID")
        assert data_passed["litellm_call_id"] == returned_data["litellm_call_id"]

    def test_add_dd_apm_tags_for_litellm_call_id_uses_dd_tracing_helper(
        self, monkeypatch
    ):
        """``DDSpanTagger.tag_call_id`` delegates to ``set_active_span_tag``."""
        mock_set_active_span_tag = MagicMock(return_value=True)
        import litellm.proxy.dd_span_tagger

        monkeypatch.setattr(
            litellm.proxy.dd_span_tagger,
            "set_active_span_tag",
            mock_set_active_span_tag,
        )
        DDSpanTagger.tag_call_id("test-call-id")
        mock_set_active_span_tag.assert_called_once_with(
            "litellm.call_id", "test-call-id"
        )

    @pytest.mark.asyncio
    async def test_should_apply_hierarchical_router_settings_as_override(
        self, monkeypatch
    ):
        """
        Test that hierarchical router settings are stored as router_settings_override
        instead of creating a full user_config with model_list.
        This approach avoids expensive per-request Router instantiation by passing
        settings as kwargs overrides to the main router.
        """
        processing_obj = ProxyBaseLLMRequestProcessing(data={})
        mock_request = MagicMock(spec=Request)
        mock_request.headers = {}

        async def mock_add_litellm_data_to_request(*args, **kwargs):
            return {}

        mock_proxy_logging_obj = self._make_pre_call_proxy_logging_obj()
        monkeypatch.setattr(
            litellm.proxy.common_request_processing,
            "add_litellm_data_to_request",
            mock_add_litellm_data_to_request,
        )
        mock_user_api_key_dict = MagicMock(spec=UserAPIKeyAuth)
        mock_proxy_config = MagicMock(spec=ProxyConfig)
        mock_router_settings = {
            "routing_strategy": "least-busy",
            "timeout": 30.0,
            "num_retries": 3,
        }
        mock_proxy_config._get_hierarchical_router_settings = AsyncMock(
            return_value=mock_router_settings
        )
        mock_llm_router = MagicMock()
        mock_prisma_client = MagicMock()
        monkeypatch.setattr(
            "litellm.proxy.proxy_server.prisma_client",
            mock_prisma_client,
        )
        (
            returned_data,
            logging_obj,
        ) = await processing_obj.common_processing_pre_call_logic(
            request=mock_request,
            general_settings={},
            user_api_key_dict=mock_user_api_key_dict,
            proxy_logging_obj=mock_proxy_logging_obj,
            proxy_config=mock_proxy_config,
            route_type="acompletion",
            llm_router=mock_llm_router,
        )
        mock_proxy_config._get_hierarchical_router_settings.assert_called_once_with(
            user_api_key_dict=mock_user_api_key_dict,
            prisma_client=mock_prisma_client,
            proxy_logging_obj=mock_proxy_logging_obj,
        )
        # get_model_list should NOT be called - we no longer copy model list for per-request routers
        mock_llm_router.get_model_list.assert_not_called()
        # Settings should be stored as router_settings_override (not user_config)
        # This allows passing them as kwargs to the main router instead of creating a new one
        assert "router_settings_override" in returned_data
        assert "user_config" not in returned_data
        router_settings_override = returned_data["router_settings_override"]
        assert router_settings_override["routing_strategy"] == "least-busy"
        assert router_settings_override["timeout"] == 30.0
        assert router_settings_override["num_retries"] == 3
        # model_list should NOT be in the override settings
        assert "model_list" not in router_settings_override

    def test_stream_timeout_header_processing(self):
        """
        Test that x-litellm-stream-timeout header gets processed and added to request data as stream_timeout.
        """
        # Synchronous on purpose: the helper under test awaits nothing, so no
        # event loop is needed.
        from litellm.proxy.litellm_pre_call_utils import LiteLLMProxyRequestSetup

        # Test with stream timeout header
        headers_with_timeout = {"x-litellm-stream-timeout": "30.5"}
        result = LiteLLMProxyRequestSetup._get_stream_timeout_from_request(
            headers_with_timeout
        )
        assert result == 30.5
        # Test without stream timeout header
        headers_without_timeout = {}
        result = LiteLLMProxyRequestSetup._get_stream_timeout_from_request(
            headers_without_timeout
        )
        assert result is None
        # Test with invalid header value (should raise ValueError when converting to float)
        headers_with_invalid = {"x-litellm-stream-timeout": "invalid"}
        with pytest.raises(ValueError):
            LiteLLMProxyRequestSetup._get_stream_timeout_from_request(
                headers_with_invalid
            )

    # ------------------------------------------------------------------
    # Success-header builder (Google GenAI native routes)
    # ------------------------------------------------------------------

    @pytest.mark.asyncio
    async def test_build_litellm_proxy_success_headers_from_llm_response(self):
        """
        Google native :generateContent uses this helper instead of base_process_llm_request;
        ensure x-litellm-* headers and callback hooks merge like the main proxy path.
        """
        mock_request = MagicMock(spec=Request)
        mock_request.headers = {}

        class _FakeGenaiResponse:
            _hidden_params = {
                "model_id": "deployment-model-id",
                "cache_key": "ck-test",
                "api_base": "https://generativelanguage.googleapis.com/v1beta",
                "response_cost": 0.001,
                "additional_headers": {"llm_provider-ratelimit-requests": "1000"},
            }

        logging_obj = MagicMock()
        logging_obj.litellm_call_id = "call-id-test"
        proxy_logging_obj = MagicMock(spec=ProxyLogging)
        proxy_logging_obj.post_call_response_headers_hook = AsyncMock(
            return_value={"x-ratelimit-remaining-requests": "999"}
        )
        headers = await ProxyBaseLLMRequestProcessing.build_litellm_proxy_success_headers_from_llm_response(
            response=_FakeGenaiResponse(),
            request_data={"model": "gemini/gemini-1.5-flash"},
            request=mock_request,
            user_api_key_dict=self._make_unlimited_user(),
            logging_obj=logging_obj,
            version="9.9.9",
            proxy_logging_obj=proxy_logging_obj,
        )
        assert headers["x-litellm-call-id"] == "call-id-test"
        assert headers["x-litellm-model-id"] == "deployment-model-id"
        assert headers["x-litellm-version"] == "9.9.9"
        assert headers["llm_provider-ratelimit-requests"] == "1000"
        assert headers["x-ratelimit-remaining-requests"] == "999"
        proxy_logging_obj.post_call_response_headers_hook.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_build_litellm_proxy_success_headers_streaming_style_iterator(self):
        """AsyncGoogleGenAIGenerateContentStreamingIterator sets _hidden_params at init; headers must propagate."""

        class _FakeStreamLike:
            def __aiter__(self):
                return self

            async def __anext__(self):
                raise StopAsyncIteration

            _hidden_params = {
                "model_id": "stream-model-id",
                "api_base": "https://generativelanguage.googleapis.com/v1beta",
                "cache_key": "",
                "response_cost": "",
                "additional_headers": {"llm_provider-x": "y"},
            }

        mock_request = MagicMock(spec=Request)
        mock_request.headers = {}
        logging_obj = MagicMock()
        logging_obj.litellm_call_id = "cid-stream"
        proxy_logging_obj = MagicMock(spec=ProxyLogging)
        proxy_logging_obj.post_call_response_headers_hook = AsyncMock(return_value={})
        headers = await ProxyBaseLLMRequestProcessing.build_litellm_proxy_success_headers_from_llm_response(
            response=_FakeStreamLike(),
            request_data={"model": "gemini/gemini-2.0-flash"},
            request=mock_request,
            user_api_key_dict=self._make_unlimited_user(),
            logging_obj=logging_obj,
            version="1.0.0",
            proxy_logging_obj=proxy_logging_obj,
        )
        assert headers["x-litellm-model-id"] == "stream-model-id"
        assert headers["x-litellm-model-api-base"] == (
            "https://generativelanguage.googleapis.com/v1beta"
        )
        assert headers["llm_provider-x"] == "y"

    @pytest.mark.asyncio
    async def test_build_litellm_proxy_success_headers_no_hidden_params_metadata_fallback(
        self,
    ):
        """When response has no _hidden_params, model_id can still come from litellm_metadata."""

        class _BareResponse:
            pass

        mock_request = MagicMock(spec=Request)
        mock_request.headers = {}
        logging_obj = MagicMock()
        logging_obj.litellm_call_id = "cid-meta"
        proxy_logging_obj = MagicMock(spec=ProxyLogging)
        proxy_logging_obj.post_call_response_headers_hook = AsyncMock(return_value={})
        headers = await ProxyBaseLLMRequestProcessing.build_litellm_proxy_success_headers_from_llm_response(
            response=_BareResponse(),
            request_data={
                "model": "gemini/gemini-1.5-flash",
                "litellm_metadata": {"model_info": {"id": "meta-model-id"}},
            },
            request=mock_request,
            user_api_key_dict=self._make_unlimited_user(),
            logging_obj=logging_obj,
            version="1.0.0",
            proxy_logging_obj=proxy_logging_obj,
        )
        assert headers["x-litellm-model-id"] == "meta-model-id"

    @pytest.mark.asyncio
    async def test_add_litellm_data_to_request_with_stream_timeout_header(self):
        """
        Test that x-litellm-stream-timeout header gets processed and added to request data
        when calling add_litellm_data_to_request.
        """
        from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request

        # Create test data with a basic completion request
        test_data = {
            "model": "gpt-3.5-turbo",
            "messages": [{"role": "user", "content": "Hello"}],
        }
        # Mock request with stream timeout header
        mock_request = MagicMock(spec=Request)
        mock_request.headers = {"x-litellm-stream-timeout": "45.0"}
        mock_request.url.path = "/v1/chat/completions"
        mock_request.method = "POST"
        mock_request.query_params = {}
        mock_request.client = None
        # Create a minimal mock with just the required attributes
        mock_user_api_key_dict = MagicMock()
        mock_user_api_key_dict.api_key = "test_api_key_hash"
        mock_user_api_key_dict.tpm_limit = None
        mock_user_api_key_dict.rpm_limit = None
        mock_user_api_key_dict.max_budget = None
        mock_user_api_key_dict.spend = 0
        mock_user_api_key_dict.allowed_model_region = None
        mock_user_api_key_dict.key_alias = None
        mock_user_api_key_dict.user_id = None
        mock_user_api_key_dict.team_id = None
        mock_user_api_key_dict.metadata = {}  # Prevent enterprise feature check
        mock_user_api_key_dict.team_metadata = None
        mock_user_api_key_dict.org_id = None
        mock_user_api_key_dict.team_alias = None
        mock_user_api_key_dict.end_user_id = None
        mock_user_api_key_dict.user_email = None
        mock_user_api_key_dict.request_route = None
        mock_user_api_key_dict.team_max_budget = None
        mock_user_api_key_dict.team_spend = None
        mock_user_api_key_dict.model_max_budget = None
        mock_user_api_key_dict.parent_otel_span = None
        mock_user_api_key_dict.team_model_aliases = None
        # Call the actual function that processes headers and adds data
        result_data = await add_litellm_data_to_request(
            data=test_data,
            request=mock_request,
            general_settings={},
            user_api_key_dict=mock_user_api_key_dict,
            version=None,
            proxy_config=MagicMock(),
        )
        # Verify that stream_timeout was extracted from header and added to request data
        assert "stream_timeout" in result_data
        assert result_data["stream_timeout"] == 45.0
        # Verify that the original test data is preserved
        assert result_data["model"] == "gpt-3.5-turbo"
        assert result_data["messages"] == [{"role": "user", "content": "Hello"}]

    # ------------------------------------------------------------------
    # Custom response headers: cost, discount, margin, key spend
    # ------------------------------------------------------------------

    def test_get_custom_headers_with_discount_info(self):
        """
        Test that discount information is correctly extracted from logging object
        and included in response headers.
        """
        mock_user_api_key_dict = self._make_spend_user()
        # Create logging object with cost breakdown including discount
        logging_obj = self._make_logging_obj(
            model="vertex_ai/gemini-pro",
            messages=[{"role": "user", "content": "test"}],
            litellm_call_id="test-call-id",
            function_id="test-function-id",
        )
        # Set cost breakdown with discount information
        logging_obj.set_cost_breakdown(
            input_cost=0.00005,
            output_cost=0.00005,
            total_cost=0.000095,  # After 5% discount
            cost_for_built_in_tools_cost_usd_dollar=0.0,
            original_cost=0.0001,
            discount_percent=0.05,
            discount_amount=0.000005,
        )
        headers = ProxyBaseLLMRequestProcessing.get_custom_headers(
            user_api_key_dict=mock_user_api_key_dict,
            call_id="test-call-id",
            response_cost=0.000095,
            litellm_logging_obj=logging_obj,
        )
        # Verify discount headers are present
        assert "x-litellm-response-cost" in headers
        assert float(headers["x-litellm-response-cost"]) == 0.000095
        assert "x-litellm-response-cost-original" in headers
        assert float(headers["x-litellm-response-cost-original"]) == 0.0001
        assert "x-litellm-response-cost-discount-amount" in headers
        assert float(headers["x-litellm-response-cost-discount-amount"]) == 0.000005

    def test_get_custom_headers_without_discount_info(self):
        """
        Test that when no discount is applied, discount headers are not included.
        """
        mock_user_api_key_dict = self._make_spend_user()
        logging_obj = self._make_logging_obj(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "test"}],
            litellm_call_id="test-call-id",
            function_id="test-function-id",
        )
        # Set cost breakdown without discount information
        logging_obj.set_cost_breakdown(
            input_cost=0.00005,
            output_cost=0.00005,
            total_cost=0.0001,
            cost_for_built_in_tools_cost_usd_dollar=0.0,
        )
        headers = ProxyBaseLLMRequestProcessing.get_custom_headers(
            user_api_key_dict=mock_user_api_key_dict,
            call_id="test-call-id",
            response_cost=0.0001,
            litellm_logging_obj=logging_obj,
        )
        assert "x-litellm-response-cost" in headers
        assert float(headers["x-litellm-response-cost"]) == 0.0001
        # Discount headers should not be in the final dict
        assert "x-litellm-response-cost-original" not in headers
        assert "x-litellm-response-cost-discount-amount" not in headers

    def test_get_custom_headers_with_margin_info(self):
        """
        Test that margin headers are included when margin is applied.
        """
        mock_user_api_key_dict = self._make_spend_user()
        logging_obj = self._make_logging_obj(
            model="gpt-4",
            messages=[],
            litellm_call_id="test-call-id-margin",
            function_id="test-function",
        )
        logging_obj.set_cost_breakdown(
            input_cost=0.00005,
            output_cost=0.00005,
            total_cost=0.00011,
            cost_for_built_in_tools_cost_usd_dollar=0.0,
            original_cost=0.0001,
            margin_percent=0.10,
            margin_total_amount=0.00001,
        )
        headers = ProxyBaseLLMRequestProcessing.get_custom_headers(
            user_api_key_dict=mock_user_api_key_dict,
            response_cost=0.00011,
            litellm_logging_obj=logging_obj,
        )
        # Verify margin headers are present
        assert "x-litellm-response-cost" in headers
        assert float(headers["x-litellm-response-cost"]) == 0.00011
        assert "x-litellm-response-cost-margin-amount" in headers
        assert float(headers["x-litellm-response-cost-margin-amount"]) == 0.00001
        assert "x-litellm-response-cost-margin-percent" in headers
        assert float(headers["x-litellm-response-cost-margin-percent"]) == 0.10

    def test_get_custom_headers_without_margin_info(self):
        """
        Test that when no margin is applied, margin headers are not included.
        """
        mock_user_api_key_dict = self._make_spend_user()
        logging_obj = self._make_logging_obj(
            model="gpt-4",
            messages=[],
            litellm_call_id="test-call-id-no-margin",
            function_id="test-function",
        )
        logging_obj.set_cost_breakdown(
            input_cost=0.00005,
            output_cost=0.00005,
            total_cost=0.0001,
            cost_for_built_in_tools_cost_usd_dollar=0.0,
        )
        headers = ProxyBaseLLMRequestProcessing.get_custom_headers(
            user_api_key_dict=mock_user_api_key_dict,
            response_cost=0.0001,
            litellm_logging_obj=logging_obj,
        )
        # Verify margin headers are not present
        assert "x-litellm-response-cost-margin-amount" not in headers
        assert "x-litellm-response-cost-margin-percent" not in headers

    def test_get_cost_breakdown_from_logging_obj_helper(self):
        """
        Test the helper function that extracts cost breakdown information.
        """
        # Test with discount info
        logging_obj = self._make_logging_obj(
            model="vertex_ai/gemini-pro",
            messages=[{"role": "user", "content": "test"}],
            litellm_call_id="test-call-id",
            function_id="test-function-id",
        )
        logging_obj.set_cost_breakdown(
            input_cost=0.00005,
            output_cost=0.00005,
            total_cost=0.000095,
            cost_for_built_in_tools_cost_usd_dollar=0.0,
            original_cost=0.0001,
            discount_percent=0.05,
            discount_amount=0.000005,
        )
        (
            original_cost,
            discount_amount,
            margin_total_amount,
            margin_percent,
        ) = _get_cost_breakdown_from_logging_obj(logging_obj)
        assert original_cost == 0.0001
        assert discount_amount == 0.000005
        assert margin_total_amount is None
        assert margin_percent is None

        # Test with margin info
        logging_obj_with_margin = self._make_logging_obj(
            model="gpt-4",
            messages=[{"role": "user", "content": "test"}],
            litellm_call_id="test-call-id-margin",
            function_id="test-function-id-margin",
        )
        logging_obj_with_margin.set_cost_breakdown(
            input_cost=0.00005,
            output_cost=0.00005,
            total_cost=0.00011,
            cost_for_built_in_tools_cost_usd_dollar=0.0,
            original_cost=0.0001,
            margin_percent=0.10,
            margin_total_amount=0.00001,
        )
        (
            original_cost,
            discount_amount,
            margin_total_amount,
            margin_percent,
        ) = _get_cost_breakdown_from_logging_obj(logging_obj_with_margin)
        assert original_cost == 0.0001
        assert discount_amount is None
        assert margin_total_amount == 0.00001
        assert margin_percent == 0.10

        # Test with no discount or margin info
        logging_obj_no_discount = self._make_logging_obj(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "test"}],
            litellm_call_id="test-call-id-2",
            function_id="test-function-id-2",
        )
        logging_obj_no_discount.set_cost_breakdown(
            input_cost=0.00005,
            output_cost=0.00005,
            total_cost=0.0001,
            cost_for_built_in_tools_cost_usd_dollar=0.0,
        )
        (
            original_cost,
            discount_amount,
            margin_total_amount,
            margin_percent,
        ) = _get_cost_breakdown_from_logging_obj(logging_obj_no_discount)
        assert original_cost is None
        assert discount_amount is None
        assert margin_total_amount is None
        assert margin_percent is None

        # Test with None logging object
        (
            original_cost,
            discount_amount,
            margin_total_amount,
            margin_percent,
        ) = _get_cost_breakdown_from_logging_obj(None)
        assert original_cost is None
        assert discount_amount is None
        assert margin_total_amount is None
        assert margin_percent is None

    def test_get_custom_headers_key_spend_includes_response_cost(self):
        """
        Test that x-litellm-key-spend header includes the current request's response_cost.
        This ensures that the spend header reflects the updated spend including the current
        request, even though spend tracking updates happen asynchronously after the response.
        """
        # Initial spend: $0.001 (mutated below for the None / reset cases)
        mock_user_api_key_dict = self._make_spend_user(spend=0.001)

        # Test case 1: response_cost is provided as float
        response_cost_1 = 0.0005  # Current request cost: $0.0005
        headers_1 = ProxyBaseLLMRequestProcessing.get_custom_headers(
            user_api_key_dict=mock_user_api_key_dict,
            call_id="test-call-id-1",
            response_cost=response_cost_1,
        )
        assert "x-litellm-key-spend" in headers_1
        expected_spend_1 = 0.001 + 0.0005  # Initial spend + current request cost
        assert float(headers_1["x-litellm-key-spend"]) == pytest.approx(
            expected_spend_1, abs=1e-10
        )
        assert float(headers_1["x-litellm-response-cost"]) == response_cost_1

        # Test case 2: response_cost is provided as string
        response_cost_2 = "0.0003"  # Current request cost as string
        headers_2 = ProxyBaseLLMRequestProcessing.get_custom_headers(
            user_api_key_dict=mock_user_api_key_dict,
            call_id="test-call-id-2",
            response_cost=response_cost_2,
        )
        assert "x-litellm-key-spend" in headers_2
        expected_spend_2 = 0.001 + 0.0003  # Initial spend + current request cost
        assert float(headers_2["x-litellm-key-spend"]) == pytest.approx(
            expected_spend_2, abs=1e-10
        )

        # Test case 3: response_cost is None (should use original spend)
        headers_3 = ProxyBaseLLMRequestProcessing.get_custom_headers(
            user_api_key_dict=mock_user_api_key_dict,
            call_id="test-call-id-3",
            response_cost=None,
        )
        assert "x-litellm-key-spend" in headers_3
        assert float(headers_3["x-litellm-key-spend"]) == 0.001

        # Test case 4: response_cost is 0 (should not change spend)
        headers_4 = ProxyBaseLLMRequestProcessing.get_custom_headers(
            user_api_key_dict=mock_user_api_key_dict,
            call_id="test-call-id-4",
            response_cost=0.0,
        )
        assert "x-litellm-key-spend" in headers_4
        assert float(headers_4["x-litellm-key-spend"]) == 0.001

        # Test case 5: user_api_key_dict.spend is None (should default to 0.0)
        mock_user_api_key_dict.spend = None
        headers_5 = ProxyBaseLLMRequestProcessing.get_custom_headers(
            user_api_key_dict=mock_user_api_key_dict,
            call_id="test-call-id-5",
            response_cost=0.0002,
        )
        assert "x-litellm-key-spend" in headers_5
        assert float(headers_5["x-litellm-key-spend"]) == 0.0002  # 0.0 + 0.0002

        # Test case 6: response_cost is negative (should not be added, use original spend)
        mock_user_api_key_dict.spend = 0.001
        headers_6 = ProxyBaseLLMRequestProcessing.get_custom_headers(
            user_api_key_dict=mock_user_api_key_dict,
            call_id="test-call-id-6",
            response_cost=-0.0001,  # Negative cost (should not be added)
        )
        assert "x-litellm-key-spend" in headers_6
        assert float(headers_6["x-litellm-key-spend"]) == 0.001

        # Test case 7: response_cost is invalid string (should fallback to original spend)
        headers_7 = ProxyBaseLLMRequestProcessing.get_custom_headers(
            user_api_key_dict=mock_user_api_key_dict,
            call_id="test-call-id-7",
            response_cost="invalid",  # Invalid string
        )
        assert "x-litellm-key-spend" in headers_7
        assert float(headers_7["x-litellm-key-spend"]) == 0.001

    # ------------------------------------------------------------------
    # Queue-time metadata
    # ------------------------------------------------------------------

    @pytest.mark.asyncio
    async def test_queue_time_seconds_is_set_in_metadata(self, monkeypatch):
        """
        Test that queue_time_seconds is correctly calculated and stored in metadata
        after add_litellm_data_to_request populates arrival_time.
        This verifies the fix for the bug where queue_time_seconds was always None
        because arrival_time was read BEFORE add_litellm_data_to_request set it.
        """
        import time

        processing_obj = ProxyBaseLLMRequestProcessing(data={})
        mock_request = MagicMock(spec=Request)
        mock_request.headers = {}
        mock_request.url = MagicMock()
        mock_request.url.path = "/v1/chat/completions"

        async def mock_add_litellm_data_to_request(*args, **kwargs):
            data = kwargs.get("data", args[0] if args else {})
            # Simulate what add_litellm_data_to_request does: set arrival_time
            data["proxy_server_request"] = {
                "url": "/v1/chat/completions",
                "method": "POST",
                "headers": {},
                "body": {},
                "arrival_time": time.time() - 0.5,  # Simulate request arrived 0.5s ago
            }
            data["metadata"] = data.get("metadata", {})
            return data

        mock_proxy_logging_obj = self._make_pre_call_proxy_logging_obj()
        monkeypatch.setattr(
            litellm.proxy.common_request_processing,
            "add_litellm_data_to_request",
            mock_add_litellm_data_to_request,
        )
        (
            returned_data,
            logging_obj,
        ) = await processing_obj.common_processing_pre_call_logic(
            request=mock_request,
            general_settings={},
            user_api_key_dict=MagicMock(spec=UserAPIKeyAuth),
            proxy_logging_obj=mock_proxy_logging_obj,
            proxy_config=MagicMock(spec=ProxyConfig),
            route_type="acompletion",
        )
        # Verify queue_time_seconds is set and non-negative
        metadata = returned_data.get("metadata", {})
        assert (
            "queue_time_seconds" in metadata
        ), "queue_time_seconds should be set in metadata"
        assert (
            metadata["queue_time_seconds"] >= 0.5
        ), f"queue_time_seconds should be at least 0.5, got {metadata['queue_time_seconds']}"
@pytest.mark.asyncio
class TestCommonRequestProcessingHelpers:
    """
    Tests for ``_parse_event_data_for_error``, ``_serialize_http_exception_detail``
    and the first-chunk / error / tracing behaviour of ``create_response``.
    """

    async def consume_stream(self, streaming_response: StreamingResponse) -> list:
        """Drain ``streaming_response.body_iterator`` and return the chunks in order."""
        return [chunk async for chunk in streaming_response.body_iterator]

    @pytest.mark.parametrize(
        "event_line, expected_code",
        [
            (
                'data: {"error": {"code": 400, "message": "bad request"}}',
                400,
            ),  # Valid integer code
            (
                'data: {"error": {"code": "401", "message": "unauthorized"}}',
                401,
            ),  # Valid string-integer code
            (
                'data: {"error": {"code": "invalid_code", "message": "error"}}',
                None,
            ),  # Invalid string code
            (
                'data: {"error": {"code": 99, "message": "too low"}}',
                None,
            ),  # Integer code too low
            (
                'data: {"error": {"code": 600, "message": "too high"}}',
                None,
            ),  # Integer code too high
            (
                'data: {"id": "123", "content": "hello"}',
                None,
            ),  # Non-error SSE event
            ("data: [DONE]", None),  # SSE [DONE] event
            ("data: ", None),  # SSE empty data event
            (
                'data: {"error": {"code": 400',
                None,
            ),  # Malformed JSON
            ("id: 123", None),  # Non-SSE event line
            (
                'data: {"error": {"message": "some error"}}',
                None,
            ),  # Error event without 'code' field
            (
                'data: {"error": {"code": null, "message": "code is null"}}',
                None,
            ),  # Error with null code
        ],
    )
    async def test_parse_event_data_for_error(self, event_line, expected_code):
        """Only ``data:`` error events carrying an in-range HTTP code yield that code."""
        assert await _parse_event_data_for_error(event_line) == expected_code

    async def test_create_streaming_response_first_chunk_is_error(self):
        """
        Test that when the first chunk is an error, a JSON error response is returned
        instead of an SSE streaming response
        """

        async def mock_generator():
            yield 'data: {"error": {"code": 403, "message": "forbidden"}}\n\n'
            yield 'data: {"content": "more data"}\n\n'
            yield "data: [DONE]\n\n"

        response = await create_response(mock_generator(), "text/event-stream", {})
        # Should return JSONResponse instead of StreamingResponse
        assert isinstance(response, JSONResponse)
        assert response.status_code == status.HTTP_403_FORBIDDEN
        # Verify the response is in standard JSON error format
        import json

        body = json.loads(response.body.decode())
        assert "error" in body
        assert body["error"]["code"] == 403
        assert body["error"]["message"] == "forbidden"

    async def test_create_streaming_response_first_chunk_not_error(self):
        """A non-error first chunk yields a 200 stream that replays every chunk."""

        async def mock_generator():
            yield 'data: {"content": "first part"}\n\n'
            yield 'data: {"content": "second part"}\n\n'
            yield "data: [DONE]\n\n"

        response = await create_response(mock_generator(), "text/event-stream", {})
        assert response.status_code == status.HTTP_200_OK
        content = await self.consume_stream(response)
        assert content == [
            'data: {"content": "first part"}\n\n',
            'data: {"content": "second part"}\n\n',
            "data: [DONE]\n\n",
        ]

    async def test_create_streaming_response_empty_generator(self):
        """An immediately exhausted async generator produces an empty 200 stream."""

        async def mock_generator():
            if False:  # Never yields
                yield
            # Implicitly raises StopAsyncIteration

        response = await create_response(mock_generator(), "text/event-stream", {})
        assert response.status_code == status.HTTP_200_OK
        content = await self.consume_stream(response)
        assert content == []

    async def test_create_streaming_response_generator_raises_stop_async_iteration_immediately(
        self,
    ):
        """A mock whose ``__anext__`` raises StopAsyncIteration gives an empty 200 stream."""
        mock_gen = AsyncMock()
        mock_gen.__anext__.side_effect = StopAsyncIteration
        response = await create_response(mock_gen, "text/event-stream", {})
        assert response.status_code == status.HTTP_200_OK
        content = await self.consume_stream(response)
        assert content == []

    async def test_create_streaming_response_generator_raises_unexpected_exception(
        self,
    ):
        """A non-HTTP exception on the first pull becomes a 500 SSE error frame + [DONE]."""
        mock_gen = AsyncMock()
        mock_gen.__anext__.side_effect = ValueError("Test error from generator")
        response = await create_response(mock_gen, "text/event-stream", {})
        assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
        content = await self.consume_stream(response)
        # Streaming SSE error frame now mirrors ProxyException.to_dict() shape
        # so streaming and non-streaming surfaces emit byte-identical errors.
        expected_error_data = {
            "error": {
                "message": "Error processing stream start",
                "type": "None",
                "param": "None",
                "code": str(status.HTTP_500_INTERNAL_SERVER_ERROR),
            }
        }
        assert len(content) == 2
        import json

        assert content[0] == f"data: {json.dumps(expected_error_data)}\n\n"
        assert content[1] == "data: [DONE]\n\n"

    async def test_create_streaming_response_generator_raises_http_exception(
        self,
    ):
        """
        Test that when a generator raises HTTPException, the response preserves
        the original status code instead of hardcoding 500.
        """
        mock_gen = AsyncMock()
        mock_gen.__anext__.side_effect = HTTPException(
            status_code=400, detail="Content blocked by guardrail"
        )
        response = await create_response(mock_gen, "text/event-stream", {})
        assert response.status_code == 400
        content = await self.consume_stream(response)
        import json

        expected_error_data = {
            "error": {
                "message": "Content blocked by guardrail",
                "type": "None",
                "param": "None",
                "code": "400",
            }
        }
        assert len(content) == 2
        assert content[0] == f"data: {json.dumps(expected_error_data)}\n\n"
        assert content[1] == "data: [DONE]\n\n"

    async def test_create_streaming_response_http_exception_dict_detail_bedrock_shape(
        self,
    ):
        """
        Bedrock-style dict detail (with the post-L3 shape) must be preserved as
        structured `provider_specific_fields` in the SSE error frame, not stringified
        into a Python-repr blob inside `error.message`. Regression for case
        2026-04-10-internal-bedrock-guardrail-streaming-error.
        """
        import json

        mock_gen = AsyncMock()
        mock_gen.__anext__.side_effect = HTTPException(
            status_code=400,
            detail={
                "error": "Violated guardrail policy",
                "bedrock_guardrail_response": "Sorry, the model cannot answer this question. Prompt is blocked",
                "guardrailIdentifier": "amgllac6xf3r",
                "guardrailVersion": "1",
                "assessments": [
                    {
                        "policy": "sensitiveInformationPolicy",
                        "matches": [
                            {
                                "category": "piiEntities",
                                "type": "NAME",
                                "action": "BLOCKED",
                                "match": "Jack",
                            }
                        ],
                    }
                ],
                "guardrail_name": "bedrock-pii-guard",
                "guardrail_mode": "post_call",
            },
        )
        response = await create_response(mock_gen, "text/event-stream", {})
        assert response.status_code == 400
        content = await self.consume_stream(response)
        assert len(content) == 2
        assert content[1] == "data: [DONE]\n\n"
        payload = json.loads(content[0][len("data: ") :].strip())
        assert payload["error"]["message"] == "Violated guardrail policy"
        assert payload["error"]["code"] == "400"
        psf = payload["error"]["provider_specific_fields"]
        assert psf["guardrail_name"] == "bedrock-pii-guard"
        assert psf["guardrail_mode"] == "post_call"
        assert psf["guardrailIdentifier"] == "amgllac6xf3r"
        assert psf["assessments"][0]["policy"] == "sensitiveInformationPolicy"
        assert psf["assessments"][0]["matches"][0]["type"] == "NAME"

    async def test_create_streaming_response_http_exception_dict_detail_nested_error_shape(
        self,
    ):
        """PANW Prisma AIRS-style nested `{"error": {"message": ...}}` detail must
        extract `error.message` as the human-readable summary while preserving the
        full payload."""
        import json

        mock_gen = AsyncMock()
        mock_gen.__anext__.side_effect = HTTPException(
            status_code=400,
            detail={
                "error": {
                    "message": "MCP request blocked: no rewritable argument field present",
                    "type": "guardrail_violation",
                    "code": "panw_prisma_airs_blocked",
                }
            },
        )
        response = await create_response(mock_gen, "text/event-stream", {})
        # Same framing guarantees as the Bedrock-shape test: the HTTPException
        # status is preserved and the stream is one error frame plus [DONE].
        assert response.status_code == 400
        content = await self.consume_stream(response)
        assert len(content) == 2
        assert content[1] == "data: [DONE]\n\n"
        payload = json.loads(content[0][len("data: ") :].strip())
        assert (
            payload["error"]["message"]
            == "MCP request blocked: no rewritable argument field present"
        )
        assert (
            payload["error"]["provider_specific_fields"]["error"]["code"]
            == "panw_prisma_airs_blocked"
        )

    async def test_serialize_http_exception_detail_helper(self):
        """Direct unit coverage for the L1 helper across all branches."""
        from litellm.proxy.common_request_processing import (
            _serialize_http_exception_detail,
        )
        import json as _json

        assert _serialize_http_exception_detail("plain") == ("plain", None)
        msg, fields = _serialize_http_exception_detail(
            {"error": "Violated", "extra": "x"}
        )
        assert msg == "Violated"
        assert fields == {"error": "Violated", "extra": "x"}
        msg, fields = _serialize_http_exception_detail(
            {"error": {"message": "blocked", "code": "x"}}
        )
        assert msg == "blocked"
        assert fields == {"error": {"message": "blocked", "code": "x"}}
        msg, fields = _serialize_http_exception_detail({"message": "top-level"})
        assert msg == "top-level"
        assert fields == {"message": "top-level"}
        msg, fields = _serialize_http_exception_detail({"weird": ["a", "b"]})
        assert msg == _json.dumps({"weird": ["a", "b"]})
        assert fields == {"weird": ["a", "b"]}
        assert _serialize_http_exception_detail(42) == ("42", None)

    async def test_create_streaming_response_first_chunk_error_string_code(self):
        """
        Test that when the first chunk contains a string error code, a JSON error response is returned
        """

        async def mock_generator():
            yield 'data: {"error": {"code": "429", "message": "too many requests"}}\n\n'
            yield "data: [DONE]\n\n"

        response = await create_response(mock_generator(), "text/event-stream", {})
        assert isinstance(response, JSONResponse)
        assert response.status_code == status.HTTP_429_TOO_MANY_REQUESTS
        # Verify the response is in standard JSON error format
        import json

        body = json.loads(response.body.decode())
        assert "error" in body
        assert body["error"]["code"] == "429"
        assert body["error"]["message"] == "too many requests"

    async def test_create_streaming_response_custom_headers(self):
        """Custom headers passed to create_response appear on the streaming response."""

        async def mock_generator():
            yield 'data: {"content": "data"}\n\n'
            yield "data: [DONE]\n\n"

        custom_headers = {"X-Custom-Header": "TestValue"}
        response = await create_response(
            mock_generator(), "text/event-stream", custom_headers
        )
        assert response.headers["x-custom-header"] == "TestValue"

    async def test_create_streaming_response_non_default_status_code(self):
        """``default_status_code`` is used for a successful (non-error) stream."""

        async def mock_generator():
            yield 'data: {"content": "data"}\n\n'
            yield "data: [DONE]\n\n"

        response = await create_response(
            mock_generator(),
            "text/event-stream",
            {},
            default_status_code=status.HTTP_201_CREATED,
        )
        assert response.status_code == status.HTTP_201_CREATED
        content = await self.consume_stream(response)
        assert content == [
            'data: {"content": "data"}\n\n',
            "data: [DONE]\n\n",
        ]

    async def test_create_streaming_response_first_chunk_is_done(self):
        """A lone [DONE] first chunk is streamed through with the default status."""

        async def mock_generator():
            yield "data: [DONE]\n\n"

        response = await create_response(mock_generator(), "text/event-stream", {})
        assert response.status_code == status.HTTP_200_OK  # Default status
        content = await self.consume_stream(response)
        assert content == ["data: [DONE]\n\n"]

    async def test_create_streaming_response_first_chunk_is_empty_data(self):
        """An empty ``data:`` first chunk is not treated as an error and is replayed."""

        async def mock_generator():
            yield "data: \n\n"
            yield 'data: {"content": "actual data"}\n\n'
            yield "data: [DONE]\n\n"

        response = await create_response(mock_generator(), "text/event-stream", {})
        assert response.status_code == status.HTTP_200_OK  # Default status
        content = await self.consume_stream(response)
        assert content == [
            "data: \n\n",
            'data: {"content": "actual data"}\n\n',
            "data: [DONE]\n\n",
        ]

    async def test_create_streaming_response_all_chunks_have_dd_trace(self):
        """Test that all stream chunks are wrapped with dd trace at the streaming generator level"""
        from unittest.mock import patch

        # Create a mock tracer
        mock_tracer = MagicMock()
        mock_span = MagicMock()
        mock_tracer.trace.return_value.__enter__.return_value = mock_span
        mock_tracer.trace.return_value.__exit__.return_value = None

        # Mock generator with multiple chunks
        async def mock_generator():
            yield 'data: {"content": "chunk 1"}\n\n'
            yield 'data: {"content": "chunk 2"}\n\n'
            yield 'data: {"content": "chunk 3"}\n\n'
            yield "data: [DONE]\n\n"

        # Patch the tracer in the common_request_processing module. The
        # per-chunk span is gated on _DD_STREAMING_TRACE_ENABLED (resolved at
        # import from the real tracer, a NullTracer by default), so enable it
        # explicitly to exercise the tracing path.
        with (
            patch("litellm.proxy.common_request_processing.tracer", mock_tracer),
            patch(
                "litellm.proxy.common_request_processing._DD_STREAMING_TRACE_ENABLED",
                True,
            ),
        ):
            response = await create_response(mock_generator(), "text/event-stream", {})
            assert response.status_code == 200
            # Consume the stream to trigger the tracer calls
            content = await self.consume_stream(response)
            # Verify all chunks are present
            assert len(content) == 4
            assert content[0] == 'data: {"content": "chunk 1"}\n\n'
            assert content[1] == 'data: {"content": "chunk 2"}\n\n'
            assert content[2] == 'data: {"content": "chunk 3"}\n\n'
            assert content[3] == "data: [DONE]\n\n"
            # Verify that tracer.trace was called for each chunk (4 chunks total)
            assert mock_tracer.trace.call_count == 4
            # Verify that each call was made with the correct operation name
            actual_calls = mock_tracer.trace.call_args_list
            assert len(actual_calls) == 4
            for i, call in enumerate(actual_calls):
                args, kwargs = call
                assert (
                    args[0] == "streaming.chunk.yield"
                ), f"Call {i} should have operation name 'streaming.chunk.yield', got {args[0]}"

    async def test_create_streaming_response_skips_dd_trace_when_disabled(self):
        """When DD tracing is disabled (the default), the per-chunk span
        context manager is skipped entirely but all chunks still stream."""
        from unittest.mock import patch

        mock_tracer = MagicMock()

        async def mock_generator():
            yield 'data: {"content": "chunk 1"}\n\n'
            yield 'data: {"content": "chunk 2"}\n\n'
            yield "data: [DONE]\n\n"

        with (
            patch("litellm.proxy.common_request_processing.tracer", mock_tracer),
            patch(
                "litellm.proxy.common_request_processing._DD_STREAMING_TRACE_ENABLED",
                False,
            ),
        ):
            response = await create_response(mock_generator(), "text/event-stream", {})
            assert response.status_code == 200
            content = await self.consume_stream(response)
            # All chunks stream through unchanged ...
            assert content == [
                'data: {"content": "chunk 1"}\n\n',
                'data: {"content": "chunk 2"}\n\n',
                "data: [DONE]\n\n",
            ]
            # ... but no per-chunk span was created.
            assert mock_tracer.trace.call_count == 0

    async def test_create_streaming_response_dd_trace_with_error_chunk(self):
        """
        Test that when the first chunk contains an error, JSONResponse is returned
        and tracing is not triggered (since it's not a streaming response)
        """
        from unittest.mock import patch

        # Create a mock tracer
        mock_tracer = MagicMock()
        mock_span = MagicMock()
        mock_tracer.trace.return_value.__enter__.return_value = mock_span
        mock_tracer.trace.return_value.__exit__.return_value = None

        # Mock generator with error in first chunk
        async def mock_generator():
            yield 'data: {"error": {"code": 400, "message": "bad request"}}\n\n'
            yield 'data: {"content": "chunk after error"}\n\n'
            yield "data: [DONE]\n\n"

        # Patch the tracer AND force the per-chunk tracing gate on. Without the
        # flag, _DD_STREAMING_TRACE_ENABLED is False by default and the
        # call_count == 0 assertion below would hold regardless of whether a
        # traced stream was produced, so the test would check nothing.
        with (
            patch("litellm.proxy.common_request_processing.tracer", mock_tracer),
            patch(
                "litellm.proxy.common_request_processing._DD_STREAMING_TRACE_ENABLED",
                True,
            ),
        ):
            response = await create_response(mock_generator(), "text/event-stream", {})
            # Should return JSONResponse instead of StreamingResponse
            assert isinstance(response, JSONResponse)
            assert response.status_code == 400
            # Verify the response is in standard JSON error format
            import json

            body = json.loads(response.body.decode())
            assert "error" in body
            assert body["error"]["code"] == 400
            assert body["error"]["message"] == "bad request"
            # Since JSONResponse is returned instead of StreamingResponse, streaming tracing should not be triggered
            # tracer.trace should not be called
            assert mock_tracer.trace.call_count == 0
class TestExtractErrorFromSSEChunk:
    """Unit tests covering _extract_error_from_sse_chunk."""

    @staticmethod
    def _assert_is_default_error(extracted):
        """Assert that *extracted* is the generic fallback error payload."""
        assert extracted["message"] == "Unknown error"
        assert extracted["type"] == "internal_server_error"
        assert extracted["code"] == "500"

    def test_extract_error_from_sse_chunk_with_valid_error(self):
        """A well-formed SSE error chunk has all of its fields extracted."""
        extracted = _extract_error_from_sse_chunk(
            'data: {"error": {"code": 403, "message": "forbidden", "type": "auth_error", "param": "api_key"}}\n\n'
        )
        expected = {
            "code": 403,
            "message": "forbidden",
            "type": "auth_error",
            "param": "api_key",
        }
        for field_name, field_value in expected.items():
            assert extracted[field_name] == field_value

    def test_extract_error_from_sse_chunk_with_string_code(self):
        """A string-typed code is returned unchanged (not cast to int)."""
        extracted = _extract_error_from_sse_chunk(
            'data: {"error": {"code": "429", "message": "too many requests"}}\n\n'
        )
        assert extracted["code"] == "429"
        assert extracted["message"] == "too many requests"

    def test_extract_error_from_sse_chunk_with_bytes(self):
        """Byte-string chunks are accepted as well as str."""
        extracted = _extract_error_from_sse_chunk(
            b'data: {"error": {"code": 500, "message": "internal error"}}\n\n'
        )
        assert extracted["code"] == 500
        assert extracted["message"] == "internal error"

    def test_extract_error_from_sse_chunk_with_done(self):
        """The [DONE] sentinel maps to the fallback error with a None param."""
        extracted = _extract_error_from_sse_chunk("data: [DONE]\n\n")
        self._assert_is_default_error(extracted)
        assert extracted["param"] is None

    def test_extract_error_from_sse_chunk_without_error_field(self):
        """A data event with no 'error' key maps to the fallback error."""
        self._assert_is_default_error(
            _extract_error_from_sse_chunk('data: {"content": "some content"}\n\n')
        )

    def test_extract_error_from_sse_chunk_with_invalid_json(self):
        """Unparseable JSON maps to the fallback error."""
        self._assert_is_default_error(
            _extract_error_from_sse_chunk("data: {invalid json}\n\n")
        )

    def test_extract_error_from_sse_chunk_without_data_prefix(self):
        """A payload lacking the 'data:' prefix maps to the fallback error."""
        self._assert_is_default_error(
            _extract_error_from_sse_chunk(
                '{"error": {"code": 400, "message": "bad request"}}\n\n'
            )
        )

    def test_extract_error_from_sse_chunk_with_empty_string(self):
        """An empty chunk maps to the fallback error."""
        self._assert_is_default_error(_extract_error_from_sse_chunk(""))

    def test_extract_error_from_sse_chunk_with_minimal_error(self):
        """An error object carrying only a message still surfaces that message."""
        extracted = _extract_error_from_sse_chunk(
            'data: {"error": {"message": "error occurred"}}\n\n'
        )
        assert extracted["message"] == "error occurred"
        # NOTE(review): values for the absent type/code/param fields are not
        # asserted here; confirm the intended defaults before pinning them.
class TestOverrideOpenAIResponseModel:
    """
    Tests for _override_openai_response_model.

    The helper rewrites ``response.model`` to the requested model, except when
    a fallback happened, the request targeted an Azure Model Router, or a
    fastest-response batch completion picked a winner. ``requested_model`` of
    None/"" leaves the response untouched.
    """

    def test_override_model_preserves_fallback_model_when_fallback_occurred_object(
        self,
    ):
        """
        Test that when a fallback occurred (x-litellm-attempted-fallbacks > 0),
        the actual model used (fallback model) is preserved instead of being
        overridden with the requested model.
        This is the regression test to ensure the model being called is properly
        displayed when a fallback happens.
        """
        requested_model = "gpt-4"
        fallback_model = "gpt-3.5-turbo"
        # Create a mock object response with fallback model
        # _hidden_params is an attribute (not a dict key) accessed via getattr
        response_obj = MagicMock()
        response_obj.model = fallback_model
        response_obj._hidden_params = {
            "additional_headers": {"x-litellm-attempted-fallbacks": 1}
        }
        # Call the function - should preserve fallback model
        _override_openai_response_model(
            response_obj=response_obj,
            requested_model=requested_model,
            log_context="test_context",
        )
        # Verify the model was NOT overridden - should still be the fallback model
        assert response_obj.model == fallback_model
        assert response_obj.model != requested_model

    def test_override_model_preserves_fallback_model_multiple_fallbacks(self):
        """
        Test that when multiple fallbacks occurred, the actual model used
        (fallback model) is preserved.
        """
        requested_model = "gpt-4"
        fallback_model = "claude-haiku-4-5-20251001"
        # Create a mock object response with fallback model
        response_obj = MagicMock()
        response_obj.model = fallback_model
        response_obj._hidden_params = {
            "additional_headers": {
                "x-litellm-attempted-fallbacks": 2  # Multiple fallbacks
            }
        }
        # Call the function - should preserve fallback model
        _override_openai_response_model(
            response_obj=response_obj,
            requested_model=requested_model,
            log_context="test_context",
        )
        # Verify the model was NOT overridden - should still be the fallback model
        assert response_obj.model == fallback_model
        assert response_obj.model != requested_model

    def test_override_model_overrides_when_no_fallback_dict(self):
        """
        Test that when no fallback occurred, the model is overridden
        to match the requested model (dict response).
        """
        requested_model = "gpt-4"
        downstream_model = "gpt-3.5-turbo"
        # Create a dict response without fallback
        # For dict responses, _hidden_params won't be found via getattr,
        # so the fallback check won't trigger and model will be overridden
        response_obj = {"model": downstream_model}
        # Call the function - should override to requested model
        _override_openai_response_model(
            response_obj=response_obj,
            requested_model=requested_model,
            log_context="test_context",
        )
        # Verify the model WAS overridden to requested model
        assert response_obj["model"] == requested_model

    def test_override_model_overrides_when_no_fallback_object(self):
        """
        Test that when no fallback occurred (object response), the model is overridden
        to match the requested model.
        """
        requested_model = "gpt-4"
        downstream_model = "gpt-3.5-turbo"
        # Create a mock object response without fallback
        response_obj = MagicMock()
        response_obj.model = downstream_model
        response_obj._hidden_params = {
            "additional_headers": {}  # No attempted_fallbacks header
        }
        # Call the function - should override to requested model
        _override_openai_response_model(
            response_obj=response_obj,
            requested_model=requested_model,
            log_context="test_context",
        )
        # Verify the model WAS overridden to requested model
        assert response_obj.model == requested_model

    def test_override_model_overrides_when_attempted_fallbacks_is_zero(self):
        """
        Test that when attempted_fallbacks is 0 (no fallback occurred),
        the model is overridden to match the requested model.
        """
        requested_model = "gpt-4"
        downstream_model = "gpt-3.5-turbo"
        # Create a mock object response
        response_obj = MagicMock()
        response_obj.model = downstream_model
        response_obj._hidden_params = {
            "additional_headers": {
                "x-litellm-attempted-fallbacks": 0  # Zero means no fallback occurred
            }
        }
        # Call the function - should override to requested model
        _override_openai_response_model(
            response_obj=response_obj,
            requested_model=requested_model,
            log_context="test_context",
        )
        # Verify the model WAS overridden to requested model
        assert response_obj.model == requested_model

    def test_override_model_overrides_when_attempted_fallbacks_is_none(self):
        """
        Test that when attempted_fallbacks is None (not set),
        the model is overridden to match the requested model.
        """
        requested_model = "gpt-4"
        downstream_model = "gpt-3.5-turbo"
        # Create a mock object response
        response_obj = MagicMock()
        response_obj.model = downstream_model
        response_obj._hidden_params = {
            "additional_headers": {"x-litellm-attempted-fallbacks": None}
        }
        # Call the function - should override to requested model
        _override_openai_response_model(
            response_obj=response_obj,
            requested_model=requested_model,
            log_context="test_context",
        )
        # Verify the model WAS overridden to requested model
        assert response_obj.model == requested_model

    def test_override_model_no_hidden_params(self):
        """
        Test that when _hidden_params is not present, the model is overridden
        to match the requested model.
        """
        requested_model = "gpt-4"
        downstream_model = "gpt-3.5-turbo"
        # Create a mock object response without _hidden_params
        response_obj = MagicMock()
        response_obj.model = downstream_model
        # MagicMock auto-creates any non-dunder attribute on first access, so
        # merely not assigning _hidden_params would still hand the code under
        # test a MagicMock rather than a missing attribute. Deleting it makes
        # attribute access raise AttributeError, so a getattr(..., default)
        # lookup falls back to its default. That is the same path a plain dict
        # response takes in the dict-response test above.
        del response_obj._hidden_params
        # Call the function - should override to requested model
        _override_openai_response_model(
            response_obj=response_obj,
            requested_model=requested_model,
            log_context="test_context",
        )
        # Verify the model WAS overridden to requested model
        assert response_obj.model == requested_model

    def test_override_model_no_requested_model(self):
        """
        Test that when requested_model is None or empty, the function returns early
        without modifying the response.
        """
        fallback_model = "gpt-3.5-turbo"
        # Create a mock object response
        response_obj = MagicMock()
        response_obj.model = fallback_model
        response_obj._hidden_params = {
            "additional_headers": {"x-litellm-attempted-fallbacks": 1}
        }
        # Call the function with None requested_model
        _override_openai_response_model(
            response_obj=response_obj,
            requested_model=None,
            log_context="test_context",
        )
        # Verify the model was not changed
        assert response_obj.model == fallback_model
        # Call with empty string
        _override_openai_response_model(
            response_obj=response_obj,
            requested_model="",
            log_context="test_context",
        )
        # Verify the model was not changed
        assert response_obj.model == fallback_model

    def test_override_model_preserves_azure_model_router_actual_model(self):
        """
        Test that when the requested model is an Azure Model Router, the actual
        model used (returned in the response) is preserved instead of being
        overridden.
        """
        requested_model = "azure_ai/model_router"
        actual_model_used = "azure_ai/gpt-5-nano-2025-08-07"
        response_obj = MagicMock()
        response_obj.model = actual_model_used
        response_obj._hidden_params = {"additional_headers": {}}
        _override_openai_response_model(
            response_obj=response_obj,
            requested_model=requested_model,
            log_context="test_context",
        )
        assert response_obj.model == actual_model_used
        assert response_obj.model != requested_model

    def test_override_model_preserves_azure_model_router_with_deployment_name(self):
        """
        Test that Azure Model Router with deployment name pattern also preserves
        the actual model used.
        """
        requested_model = "azure_ai/model_router/my-deployment"
        actual_model_used = "azure_ai/gpt-4.1-nano-2025-04-14"
        response_obj = MagicMock()
        response_obj.model = actual_model_used
        response_obj._hidden_params = {"additional_headers": {}}
        _override_openai_response_model(
            response_obj=response_obj,
            requested_model=requested_model,
            log_context="test_context",
        )
        assert response_obj.model == actual_model_used
        assert response_obj.model != requested_model

    def test_override_model_preserves_azure_model_router_with_hyphen(self):
        """
        Test that Azure Model Router with hyphen pattern (model-router) also preserves
        the actual model used.
        """
        requested_model = "azure_ai/model-router"
        actual_model_used = "azure_ai/gpt-5-nano-2025-08-07"
        response_obj = MagicMock()
        response_obj.model = actual_model_used
        response_obj._hidden_params = {"additional_headers": {}}
        _override_openai_response_model(
            response_obj=response_obj,
            requested_model=requested_model,
            log_context="test_context",
        )
        assert response_obj.model == actual_model_used
        assert response_obj.model != requested_model

    def test_override_model_uses_winning_model_for_fastest_response(self):
        """
        Test that when fastest_response batch completion is used with a
        comma-separated model list, the response model is set to the winning
        model's group name (not the comma-separated list).
        """
        requested_model = "openai/gpt-4o,gemini/gemini-2.5-flash"
        winning_model_group = "gemini/gemini-2.5-flash"
        downstream_model = "gemini-2.5-flash"
        response_obj = MagicMock()
        response_obj.model = downstream_model
        response_obj._hidden_params = {
            "fastest_response_batch_completion": True,
            "additional_headers": {
                "x-litellm-model-group": winning_model_group,
            },
        }
        _override_openai_response_model(
            response_obj=response_obj,
            requested_model=requested_model,
            log_context="test_context",
        )
        assert response_obj.model == winning_model_group
        assert response_obj.model != requested_model

    def test_override_model_preserves_response_when_fastest_response_no_model_group(
        self,
    ):
        """
        Test that when fastest_response is set but no model group header is
        available, the actual downstream model is preserved.
        """
        requested_model = "openai/gpt-4o,gemini/gemini-2.5-flash"
        downstream_model = "gpt-4o-2024-08-06"
        response_obj = MagicMock()
        response_obj.model = downstream_model
        response_obj._hidden_params = {
            "fastest_response_batch_completion": True,
            "additional_headers": {},
        }
        _override_openai_response_model(
            response_obj=response_obj,
            requested_model=requested_model,
            log_context="test_context",
        )
        assert response_obj.model == downstream_model

    def test_override_model_normal_when_fastest_response_not_set(self):
        """
        Test that when fastest_response_batch_completion is not set, the
        normal override behavior applies (model is set to requested_model).
        """
        requested_model = "openai/gpt-4o"
        downstream_model = "gpt-4o-2024-08-06"
        response_obj = MagicMock()
        response_obj.model = downstream_model
        response_obj._hidden_params = {
            "additional_headers": {
                "x-litellm-model-group": "openai/gpt-4o",
            },
        }
        _override_openai_response_model(
            response_obj=response_obj,
            requested_model=requested_model,
            log_context="test_context",
        )
        assert response_obj.model == requested_model
class TestIsAzureModelRouterRequest:
    """Unit tests for the _is_azure_model_router_request predicate."""

    def test_detects_model_router_with_underscore(self):
        """The underscore spelling matches, with or without a deployment suffix."""
        for model_name in (
            "azure_ai/model_router",
            "azure_ai/model_router/my-deployment",
        ):
            assert _is_azure_model_router_request(model_name) is True

    def test_detects_model_router_with_hyphen(self):
        """The hyphen spelling matches, with or without a provider prefix."""
        for model_name in ("azure_ai/model-router", "model-router"):
            assert _is_azure_model_router_request(model_name) is True

    def test_rejects_regular_models(self):
        """Ordinary model names are not treated as a model router."""
        for model_name in ("azure_ai/gpt-4", "gpt-4", "openai/gpt-3.5-turbo"):
            assert _is_azure_model_router_request(model_name) is False
class TestStreamingOverheadHeader:
"""
Tests that x-litellm-overhead-duration-ms is emitted in streaming responses.
Regression tests for: streaming requests not including overhead header.
"""
def test_get_custom_headers_includes_overhead_when_set(self):
    """
    When hidden_params carries litellm_overhead_time_ms, get_custom_headers()
    emits it as the string-valued x-litellm-overhead-duration-ms header.
    """
    key_auth = MagicMock(spec=UserAPIKeyAuth)
    for attr_name, attr_value in (
        ("tpm_limit", None),
        ("rpm_limit", None),
        ("max_budget", None),
        ("spend", 0.0),
        ("allowed_model_region", None),
    ):
        setattr(key_auth, attr_name, attr_value)

    result_headers = ProxyBaseLLMRequestProcessing.get_custom_headers(
        user_api_key_dict=key_auth,
        call_id="test-call-id",
        model_id="test-model-id",
        cache_key="",
        api_base="https://api.openai.com",
        version="1.0.0",
        response_cost=0.001,
        model_region="",
        hidden_params={
            "litellm_overhead_time_ms": 42.5,
            "_response_ms": 500.0,
            "model_id": "test-model-id",
            "api_base": "https://api.openai.com",
        },
    )

    assert "x-litellm-overhead-duration-ms" in result_headers
    assert result_headers["x-litellm-overhead-duration-ms"] == "42.5"
def test_get_custom_headers_omits_overhead_when_none(self):
"""
get_custom_headers() omits x-litellm-overhead-duration-ms
when litellm_overhead_time_ms is not in hidden_params.
"""
mock_user_api_key_dict = MagicMock(spec=UserAPIKeyAuth)
mock_user_api_key_dict.tpm_limit = None
mock_user_api_key_dict.rpm_limit = None
mock_user_api_key_dict.max_budget = None
mock_user_api_key_dict.spend = 0.0
mock_user_api_key_dict.allowed_model_region = None
hidden_params = {
"_response_ms": 500.0,
"model_id": "test-model-id",
}
headers = ProxyBaseLLMRequestProcessing.get_custom_headers(
user_api_key_dict=mock_user_api_key_dict,
call_id="test-call-id",
model_id="test-model-id",
cache_key="",
api_base="https://api.openai.com",
version="1.0.0",
response_cost=0.001,
model_region="",
hidden_params=hidden_params,
)
# Should be absent (None gets filtered by exclude_values)
assert "x-litellm-overhead-duration-ms" not in headers
def test_update_response_metadata_sets_overhead_on_stream_wrapper(self):
"""
update_response_metadata() sets litellm_overhead_time_ms on
a streaming response's _hidden_params when llm_api_duration_ms is available.
"""
from litellm.litellm_core_utils.llm_response_utils.response_metadata import (
update_response_metadata,
)
# Mock the logging object with llm_api_duration_ms set
mock_logging_obj = MagicMock()
mock_logging_obj.model_call_details = {
"llm_api_duration_ms": 200.0,
"litellm_params": {},
}
mock_logging_obj.caching_details = None
mock_logging_obj.callback_duration_ms = None
mock_logging_obj.litellm_call_id = "test-call-id"
mock_logging_obj._response_cost_calculator = MagicMock(return_value=0.001)
# Simulate a streaming result object with _hidden_params (like CustomStreamWrapper)
stream_result = MagicMock()
stream_result._hidden_params = {
"model_id": "test-model-id",
"api_base": "https://api.openai.com",
"additional_headers": {},
}
start_time = datetime.datetime.now() - datetime.timedelta(milliseconds=300)
end_time = datetime.datetime.now()
update_response_metadata(
result=stream_result,
logging_obj=mock_logging_obj,
model="gpt-4o",
kwargs={},
start_time=start_time,
end_time=end_time,
)
assert "litellm_overhead_time_ms" in stream_result._hidden_params
overhead = stream_result._hidden_params["litellm_overhead_time_ms"]
assert overhead is not None
assert isinstance(overhead, float)
# overhead = total_response_ms (~300ms) - llm_api_duration_ms (200ms) = ~100ms
assert overhead > 0
@pytest.mark.asyncio
async def test_streaming_response_includes_overhead_header(self):
"""
StreamingResponse returned by create_response() includes
x-litellm-overhead-duration-ms in its headers.
"""
async def mock_generator() -> AsyncGenerator[str, None]:
yield 'data: {"id":"chatcmpl-test","choices":[{"delta":{"content":"hi"}}]}\n\n'
yield "data: [DONE]\n\n"
headers = {
"x-litellm-overhead-duration-ms": "42.5",
"x-litellm-call-id": "test-call-id",
"x-litellm-model-id": "test-model-id",
}
response = await create_response(
generator=mock_generator(),
media_type="text/event-stream",
headers=headers,
)
assert isinstance(response, StreamingResponse)
assert response.headers.get("x-litellm-overhead-duration-ms") == "42.5"
def test_streaming_overhead_header_in_custom_headers_from_stream_hidden_params(
self,
):
"""
Verifies that when get_custom_headers() is called with a streaming
response's hidden_params (containing litellm_overhead_time_ms),
the x-litellm-overhead-duration-ms header is correctly populated.
This tests the critical path: update_response_metadata sets the value
→ get_custom_headers reads it → StreamingResponse header is set.
"""
mock_user_api_key_dict = MagicMock(spec=UserAPIKeyAuth)
mock_user_api_key_dict.tpm_limit = None
mock_user_api_key_dict.rpm_limit = None
mock_user_api_key_dict.max_budget = None
mock_user_api_key_dict.spend = 0.0
mock_user_api_key_dict.allowed_model_region = None
# This is what CustomStreamWrapper._hidden_params looks like after
# update_response_metadata() has been called on it
hidden_params = {
"model_id": "openai-gpt4o-deployment",
"api_base": "https://api.openai.com",
"additional_headers": {},
"litellm_overhead_time_ms": 55.3, # set by update_response_metadata
"_response_ms": 280.0,
"litellm_call_id": "test-call-id",
"response_cost": 0.002,
"cache_key": None,
"fastest_response_batch_completion": None,
"callback_duration_ms": None,
}
custom_headers = ProxyBaseLLMRequestProcessing.get_custom_headers(
user_api_key_dict=mock_user_api_key_dict,
call_id="test-call-id",
model_id=hidden_params.get("model_id"),
cache_key=hidden_params.get("cache_key") or "",
api_base=hidden_params.get("api_base") or "",
version="1.0.0",
response_cost=hidden_params.get("response_cost"),
model_region="",
hidden_params=hidden_params,
)
# The overhead header must be present and correct
assert "x-litellm-overhead-duration-ms" in custom_headers, (
"x-litellm-overhead-duration-ms header must be emitted during streaming. "
"It was missing — this is the streaming overhead header regression."
)
assert custom_headers["x-litellm-overhead-duration-ms"] == "55.3"
class TestDDSpanTaggerTagRequest:
    """Tests for DDSpanTagger.tag_request - key/model DD span tagging."""

    # Dotted path of the tag setter that DDSpanTagger calls; patched in every test.
    _SET_TAG_TARGET = "litellm.proxy.dd_span_tagger.set_active_span_tag"

    def _make_user_api_key_dict(self, key_alias=None, token=None):
        """Build a real UserAPIKeyAuth with only key_alias and token overridden."""
        from litellm.proxy._types import UserAPIKeyAuth

        auth = UserAPIKeyAuth()
        auth.key_alias, auth.token = key_alias, token
        return auth

    def test_tags_key_alias_and_model(self):
        """key_alias and requested_model are set on the span when present."""
        auth = self._make_user_api_key_dict(key_alias="my-prod-key", token="hashed123")

        with patch(self._SET_TAG_TARGET) as set_tag_spy:
            DDSpanTagger.tag_request(user_api_key_dict=auth, requested_model="gpt-4o")

        # Every expected tag must appear among the recorded calls, in any order.
        set_tag_spy.assert_has_calls(
            [
                call("litellm.key_alias", "my-prod-key"),
                call("litellm.key_hash", "hashed123"),
                call("litellm.requested_model", "gpt-4o"),
            ],
            any_order=True,
        )

    def test_no_tags_when_key_absent(self):
        """No key tags are set when key_alias and token are None (e.g. 401 path)."""
        auth = self._make_user_api_key_dict()

        with patch(self._SET_TAG_TARGET) as set_tag_spy:
            DDSpanTagger.tag_request(user_api_key_dict=auth, requested_model=None)

        set_tag_spy.assert_not_called()

    def test_only_model_tagged_when_no_key_info(self):
        """requested_model is tagged even when there's no key info."""
        auth = self._make_user_api_key_dict()

        with patch(self._SET_TAG_TARGET) as set_tag_spy:
            DDSpanTagger.tag_request(
                user_api_key_dict=auth, requested_model="claude-3-5-sonnet"
            )

        set_tag_spy.assert_called_once_with(
            "litellm.requested_model", "claude-3-5-sonnet"
        )
class TestHasAttributeErrorInChain:
    """Tests for _has_attribute_error_in_chain helper.

    The helper is given an exception and must report whether an AttributeError
    is reachable through ``__cause__``, ``__context__`` or an
    ``original_exception`` attribute. A cyclic chain must not hang it.
    """

    @staticmethod
    def _wrapped(inner, link_attr):
        """Return RuntimeError("wrapper") whose *link_attr* points at *inner*."""
        outer = RuntimeError("wrapper")
        setattr(outer, link_attr, inner)
        return outer

    def test_direct_attribute_error(self):
        err = AttributeError("'str' object has no attribute 'get'")
        assert _has_attribute_error_in_chain(err) is True

    def test_no_attribute_error(self):
        err = ValueError("some other error")
        assert _has_attribute_error_in_chain(err) is False

    def test_attribute_error_in_cause(self):
        chained = self._wrapped(AttributeError("bad attribute"), "__cause__")
        assert _has_attribute_error_in_chain(chained) is True

    def test_attribute_error_in_context(self):
        chained = self._wrapped(AttributeError("bad attribute"), "__context__")
        assert _has_attribute_error_in_chain(chained) is True

    def test_attribute_error_in_original_exception(self):
        chained = self._wrapped(AttributeError("bad attribute"), "original_exception")
        assert _has_attribute_error_in_chain(chained) is True

    def test_attribute_error_nested_two_levels(self):
        """Simulates the real failure: AttributeError -> OpenAIException -> APIConnectionError."""
        root = AttributeError("'str' object has no attribute 'get'")
        provider_err = Exception("OpenAIException wrapper")
        provider_err.__context__ = root
        connection_err = Exception("APIConnectionError wrapper")
        connection_err.__context__ = provider_err
        assert _has_attribute_error_in_chain(connection_err) is True

    def test_depth_limit_prevents_infinite_loop(self):
        """Ensure circular references don't cause infinite recursion."""
        first = RuntimeError("a")
        second = RuntimeError("b")
        first.__context__ = second
        second.__context__ = first  # circular
        assert _has_attribute_error_in_chain(first) is False
@pytest.mark.asyncio
class TestHandleLLMApiExceptionDictDetail:
    """
    Coverage for `_handle_llm_api_exception` HTTPException branch (Site 2).

    Regression for case 2026-04-10-internal-bedrock-guardrail-streaming-error:
    dict-detail HTTPExceptions raised by guardrails must round-trip cleanly
    through ProxyException instead of being str()-mangled into a Python repr.
    """

    async def _invoke(self, exc: Exception):
        """
        Feed *exc* to ``_handle_llm_api_exception`` and return the ProxyException it raises.

        Fails the calling test if no ProxyException is raised. Any other
        exception type propagates unchanged.
        """
        from litellm.proxy._types import ProxyException, UserAPIKeyAuth

        processor = ProxyBaseLLMRequestProcessing(data={})
        user_api_key_dict = UserAPIKeyAuth(api_key="sk-test")
        proxy_logging_obj = MagicMock()
        # NOTE(review): AsyncMock because the handler presumably awaits these
        # hooks; the returned values are stand-ins for "no extra work / no headers".
        proxy_logging_obj.post_call_failure_hook = AsyncMock(return_value=None)
        proxy_logging_obj.post_call_response_headers_hook = AsyncMock(return_value={})

        with pytest.raises(ProxyException) as exc_info:
            await processor._handle_llm_api_exception(
                e=exc,
                user_api_key_dict=user_api_key_dict,
                proxy_logging_obj=proxy_logging_obj,
            )
        return exc_info.value

    async def test_dict_detail_bedrock_shape_preserved(self):
        exc = HTTPException(
            status_code=400,
            detail={
                "error": "Violated guardrail policy",
                "bedrock_guardrail_response": "...",
                "guardrail_name": "bedrock-pii-guard",
            },
        )

        proxy_exc = await self._invoke(exc)

        assert proxy_exc.message == "Violated guardrail policy"
        assert (
            proxy_exc.provider_specific_fields["guardrail_name"] == "bedrock-pii-guard"
        )
        # No Python repr leakage of the dict into the message field.
        assert "{'error':" not in proxy_exc.message

    async def test_string_detail_unchanged(self):
        exc = HTTPException(status_code=400, detail="Content blocked by guardrail")

        proxy_exc = await self._invoke(exc)

        assert proxy_exc.message == "Content blocked by guardrail"
        assert proxy_exc.provider_specific_fields is None
class TestAsyncStreamingDataGeneratorFastPath:
    """Fast/slow path branching in async_streaming_data_generator."""

    @pytest.fixture(autouse=True)
    def _reset_callback_capabilities_cache(self):
        """
        Clear ``ProxyLogging._callback_capabilities_cache`` around every test.

        Each test installs its own ``litellm.callbacks``. Clearing at setup makes
        ProxyLogging see this test's callbacks. Clearing again at teardown (which
        runs even when an assertion fails) keeps this test's cached result from
        leaking into later tests.
        """
        ProxyLogging._callback_capabilities_cache.clear()
        yield
        ProxyLogging._callback_capabilities_cache.clear()

    @staticmethod
    async def _aiter(items):
        """Yield *items* one at a time as an async iterator (stand-in for an upstream stream)."""
        for item in items:
            yield item

    @pytest.mark.asyncio
    async def test_fast_path_skips_per_chunk_hook(self, monkeypatch):
        """With no callbacks/guardrails/cost-injection, chunks pass through
        unchanged and the per-chunk hook is NOT awaited."""
        monkeypatch.setattr(litellm, "callbacks", [])
        proxy_logging_obj = ProxyLogging(user_api_key_cache=MagicMock())
        hook_spy = AsyncMock(side_effect=lambda **kw: kw["response"])
        monkeypatch.setattr(
            proxy_logging_obj, "async_post_call_streaming_hook", hook_spy
        )

        chunks = [b"event: a\ndata: {}\n\n", b"event: b\ndata: {}\n\n"]
        out = [
            c
            async for c in ProxyBaseLLMRequestProcessing.async_streaming_data_generator(
                response=self._aiter(chunks),
                user_api_key_dict=MagicMock(spec=UserAPIKeyAuth),
                request_data={"model": "claude-x"},
                proxy_logging_obj=proxy_logging_obj,
                serialize_chunk=ProxyBaseLLMRequestProcessing.return_sse_chunk,
                serialize_error=lambda e: "data: error\n\n",
            )
        ]

        assert out == chunks  # bytes pass through return_sse_chunk untouched
        hook_spy.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_slow_path_runs_per_chunk_hook(self, monkeypatch):
        """A callback that overrides async_post_call_streaming_hook forces the
        slow path and the per-chunk hook is invoked."""

        class _StreamingCb(CustomLogger):
            async def async_post_call_streaming_hook(self, user_api_key_dict, response):
                return response

        monkeypatch.setattr(litellm, "callbacks", [_StreamingCb()])
        proxy_logging_obj = ProxyLogging(user_api_key_cache=MagicMock())
        hook_spy = AsyncMock(side_effect=lambda **kw: kw["response"])
        monkeypatch.setattr(
            proxy_logging_obj, "async_post_call_streaming_hook", hook_spy
        )

        out = [
            c
            async for c in ProxyBaseLLMRequestProcessing.async_streaming_data_generator(
                response=self._aiter([{"type": "message_stop"}]),
                user_api_key_dict=MagicMock(spec=UserAPIKeyAuth),
                request_data={"model": "claude-x"},
                proxy_logging_obj=proxy_logging_obj,
                serialize_chunk=ProxyBaseLLMRequestProcessing.return_sse_chunk,
                serialize_error=lambda e: "data: error\n\n",
            )
        ]

        assert len(out) == 1
        hook_spy.assert_awaited_once()