import copy
from unittest.mock import AsyncMock, MagicMock

import pytest
from fastapi import Request, status
from fastapi.responses import JSONResponse, StreamingResponse

import litellm
from litellm._uuid import uuid
from litellm.integrations.opentelemetry import UserAPIKeyAuth
from litellm.proxy.common_request_processing import (
    ProxyBaseLLMRequestProcessing,
    ProxyConfig,
    _extract_error_from_sse_chunk,
    _get_cost_breakdown_from_logging_obj,
    _parse_event_data_for_error,
    create_response,
)
from litellm.proxy.utils import ProxyLogging


class TestProxyBaseLLMRequestProcessing:
    """Tests for ProxyBaseLLMRequestProcessing: pre-call logic, stream-timeout
    header handling, and the custom response headers (cost / discount / margin /
    key spend)."""

    @pytest.mark.asyncio
    async def test_common_processing_pre_call_logic_pre_call_hook_receives_litellm_call_id(
        self, monkeypatch
    ):
        """
        Test that pre_call_hook is called exactly once with a `data` kwarg that
        carries a valid UUID `litellm_call_id`, and that the same id is present
        in the data returned by common_processing_pre_call_logic.
        """
        processing_obj = ProxyBaseLLMRequestProcessing(data={})
        mock_request = MagicMock(spec=Request)
        mock_request.headers = {}

        async def mock_add_litellm_data_to_request(*args, **kwargs):
            return {}

        async def mock_common_processing_pre_call_logic(
            user_api_key_dict, data, call_type
        ):
            # Return a copy so the hook's result is a distinct object from its input.
            data_copy = copy.deepcopy(data)
            return data_copy

        mock_proxy_logging_obj = MagicMock(spec=ProxyLogging)
        mock_proxy_logging_obj.pre_call_hook = AsyncMock(
            side_effect=mock_common_processing_pre_call_logic
        )
        monkeypatch.setattr(
            litellm.proxy.common_request_processing,
            "add_litellm_data_to_request",
            mock_add_litellm_data_to_request,
        )
        mock_general_settings = {}
        mock_user_api_key_dict = MagicMock(spec=UserAPIKeyAuth)
        mock_proxy_config = MagicMock(spec=ProxyConfig)
        route_type = "acompletion"

        # Call the actual method.
        (
            returned_data,
            logging_obj,
        ) = await processing_obj.common_processing_pre_call_logic(
            request=mock_request,
            general_settings=mock_general_settings,
            user_api_key_dict=mock_user_api_key_dict,
            proxy_logging_obj=mock_proxy_logging_obj,
            proxy_config=mock_proxy_config,
            route_type=route_type,
        )

        mock_proxy_logging_obj.pre_call_hook.assert_called_once()

        _, call_kwargs = mock_proxy_logging_obj.pre_call_hook.call_args
        data_passed = call_kwargs.get("data", {})

        assert "litellm_call_id" in data_passed
        try:
            uuid.UUID(data_passed["litellm_call_id"])
        except ValueError:
            pytest.fail("litellm_call_id is not a valid UUID")
        assert data_passed["litellm_call_id"] == returned_data["litellm_call_id"]

    @pytest.mark.asyncio
    async def test_should_apply_hierarchical_router_settings_to_user_config(
        self, monkeypatch
    ):
        """
        Test that the settings returned by
        proxy_config._get_hierarchical_router_settings, together with the
        router's model list, end up in returned_data["user_config"].
        """
        processing_obj = ProxyBaseLLMRequestProcessing(data={})
        mock_request = MagicMock(spec=Request)
        mock_request.headers = {}

        async def mock_add_litellm_data_to_request(*args, **kwargs):
            return {}

        async def mock_common_processing_pre_call_logic(
            user_api_key_dict, data, call_type
        ):
            data_copy = copy.deepcopy(data)
            return data_copy

        mock_proxy_logging_obj = MagicMock(spec=ProxyLogging)
        mock_proxy_logging_obj.pre_call_hook = AsyncMock(
            side_effect=mock_common_processing_pre_call_logic
        )
        monkeypatch.setattr(
            litellm.proxy.common_request_processing,
            "add_litellm_data_to_request",
            mock_add_litellm_data_to_request,
        )

        mock_general_settings = {}
        mock_user_api_key_dict = MagicMock(spec=UserAPIKeyAuth)
        mock_proxy_config = MagicMock(spec=ProxyConfig)

        mock_router_settings = {
            "routing_strategy": "least-busy",
            "timeout": 30.0,
            "num_retries": 3,
        }
        mock_proxy_config._get_hierarchical_router_settings = AsyncMock(
            return_value=mock_router_settings
        )

        mock_model_list = [
            {"model_name": "gpt-3.5-turbo", "litellm_params": {"model": "gpt-3.5-turbo"}},
            {"model_name": "gpt-4", "litellm_params": {"model": "gpt-4"}},
        ]
        mock_llm_router = MagicMock()
        mock_llm_router.get_model_list = MagicMock(return_value=mock_model_list)

        mock_prisma_client = MagicMock()
        monkeypatch.setattr(
            "litellm.proxy.proxy_server.prisma_client",
            mock_prisma_client,
        )

        route_type = "acompletion"

        returned_data, logging_obj = await processing_obj.common_processing_pre_call_logic(
            request=mock_request,
            general_settings=mock_general_settings,
            user_api_key_dict=mock_user_api_key_dict,
            proxy_logging_obj=mock_proxy_logging_obj,
            proxy_config=mock_proxy_config,
            route_type=route_type,
            llm_router=mock_llm_router,
        )

        mock_proxy_config._get_hierarchical_router_settings.assert_called_once_with(
            user_api_key_dict=mock_user_api_key_dict,
            prisma_client=mock_prisma_client,
        )
        mock_llm_router.get_model_list.assert_called_once()

        assert "user_config" in returned_data
        user_config = returned_data["user_config"]
        assert user_config["model_list"] == mock_model_list
        assert user_config["routing_strategy"] == "least-busy"
        assert user_config["timeout"] == 30.0
        assert user_config["num_retries"] == 3

    @pytest.mark.asyncio
    async def test_stream_timeout_header_processing(self):
        """
        Test that x-litellm-stream-timeout header gets processed and added to request data as stream_timeout.
        """
        from litellm.proxy.litellm_pre_call_utils import LiteLLMProxyRequestSetup

        # Test with stream timeout header
        headers_with_timeout = {"x-litellm-stream-timeout": "30.5"}
        result = LiteLLMProxyRequestSetup._get_stream_timeout_from_request(headers_with_timeout)
        assert result == 30.5

        # Test without stream timeout header
        headers_without_timeout = {}
        result = LiteLLMProxyRequestSetup._get_stream_timeout_from_request(headers_without_timeout)
        assert result is None

        # Test with invalid header value (should raise ValueError when converting to float)
        headers_with_invalid = {"x-litellm-stream-timeout": "invalid"}
        with pytest.raises(ValueError):
            LiteLLMProxyRequestSetup._get_stream_timeout_from_request(headers_with_invalid)

    @pytest.mark.asyncio
    async def test_add_litellm_data_to_request_with_stream_timeout_header(self):
        """
        Test that x-litellm-stream-timeout header gets processed and added to request data
        when calling add_litellm_data_to_request.
        """
        from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request

        # Create test data with a basic completion request
        test_data = {
            "model": "gpt-3.5-turbo",
            "messages": [{"role": "user", "content": "Hello"}]
        }

        # Mock request with stream timeout header
        mock_request = MagicMock(spec=Request)
        mock_request.headers = {"x-litellm-stream-timeout": "45.0"}
        mock_request.url.path = "/v1/chat/completions"
        mock_request.method = "POST"
        mock_request.query_params = {}
        mock_request.client = None

        # Create a minimal mock with just the required attributes
        mock_user_api_key_dict = MagicMock()
        mock_user_api_key_dict.api_key = "test_api_key_hash"
        mock_user_api_key_dict.tpm_limit = None
        mock_user_api_key_dict.rpm_limit = None
        mock_user_api_key_dict.max_budget = None
        mock_user_api_key_dict.spend = 0
        mock_user_api_key_dict.allowed_model_region = None
        mock_user_api_key_dict.key_alias = None
        mock_user_api_key_dict.user_id = None
        mock_user_api_key_dict.team_id = None
        mock_user_api_key_dict.metadata = {}  # Prevent enterprise feature check
        mock_user_api_key_dict.team_metadata = None
        mock_user_api_key_dict.org_id = None
        mock_user_api_key_dict.team_alias = None
        mock_user_api_key_dict.end_user_id = None
        mock_user_api_key_dict.user_email = None
        mock_user_api_key_dict.request_route = None
        mock_user_api_key_dict.team_max_budget = None
        mock_user_api_key_dict.team_spend = None
        mock_user_api_key_dict.model_max_budget = None
        mock_user_api_key_dict.parent_otel_span = None
        mock_user_api_key_dict.team_model_aliases = None

        general_settings = {}
        mock_proxy_config = MagicMock()

        # Call the actual function that processes headers and adds data
        result_data = await add_litellm_data_to_request(
            data=test_data,
            request=mock_request,
            general_settings=general_settings,
            user_api_key_dict=mock_user_api_key_dict,
            version=None,
            proxy_config=mock_proxy_config,
        )

        # Verify that stream_timeout was extracted from header and added to request data
        assert "stream_timeout" in result_data
        assert result_data["stream_timeout"] == 45.0

        # Verify that the original test data is preserved
        assert result_data["model"] == "gpt-3.5-turbo"
        assert result_data["messages"] == [{"role": "user", "content": "Hello"}]

    def test_get_custom_headers_with_discount_info(self):
        """
        Test that discount information is correctly extracted from logging object
        and included in response headers.
        """
        from litellm.litellm_core_utils.litellm_logging import (
            Logging as LiteLLMLoggingObj,
        )

        # Create mock user API key dict
        mock_user_api_key_dict = MagicMock(spec=UserAPIKeyAuth)
        mock_user_api_key_dict.tpm_limit = None
        mock_user_api_key_dict.rpm_limit = None
        mock_user_api_key_dict.max_budget = None
        mock_user_api_key_dict.spend = 0

        # Create logging object with cost breakdown including discount
        logging_obj = LiteLLMLoggingObj(
            model="vertex_ai/gemini-pro",
            messages=[{"role": "user", "content": "test"}],
            stream=False,
            call_type="completion",
            start_time=None,
            litellm_call_id="test-call-id",
            function_id="test-function-id",
        )

        # Set cost breakdown with discount information
        logging_obj.set_cost_breakdown(
            input_cost=0.00005,
            output_cost=0.00005,
            total_cost=0.000095,  # After 5% discount
            cost_for_built_in_tools_cost_usd_dollar=0.0,
            original_cost=0.0001,
            discount_percent=0.05,
            discount_amount=0.000005,
        )

        # Call get_custom_headers with discount info
        headers = ProxyBaseLLMRequestProcessing.get_custom_headers(
            user_api_key_dict=mock_user_api_key_dict,
            call_id="test-call-id",
            response_cost=0.000095,
            litellm_logging_obj=logging_obj,
        )

        # Verify discount headers are present
        assert "x-litellm-response-cost" in headers
        assert float(headers["x-litellm-response-cost"]) == 0.000095

        assert "x-litellm-response-cost-original" in headers
        assert float(headers["x-litellm-response-cost-original"]) == 0.0001

        assert "x-litellm-response-cost-discount-amount" in headers
        assert float(headers["x-litellm-response-cost-discount-amount"]) == 0.000005

    def test_get_custom_headers_without_discount_info(self):
        """
        Test that when no discount is applied, discount headers are not included.
        """
        from litellm.litellm_core_utils.litellm_logging import (
            Logging as LiteLLMLoggingObj,
        )

        # Create mock user API key dict
        mock_user_api_key_dict = MagicMock(spec=UserAPIKeyAuth)
        mock_user_api_key_dict.tpm_limit = None
        mock_user_api_key_dict.rpm_limit = None
        mock_user_api_key_dict.max_budget = None
        mock_user_api_key_dict.spend = 0

        # Create logging object without discount
        logging_obj = LiteLLMLoggingObj(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "test"}],
            stream=False,
            call_type="completion",
            start_time=None,
            litellm_call_id="test-call-id",
            function_id="test-function-id",
        )

        # Set cost breakdown without discount information
        logging_obj.set_cost_breakdown(
            input_cost=0.00005,
            output_cost=0.00005,
            total_cost=0.0001,
            cost_for_built_in_tools_cost_usd_dollar=0.0,
        )

        # Call get_custom_headers
        headers = ProxyBaseLLMRequestProcessing.get_custom_headers(
            user_api_key_dict=mock_user_api_key_dict,
            call_id="test-call-id",
            response_cost=0.0001,
            litellm_logging_obj=logging_obj,
        )

        # Verify discount headers are NOT present
        assert "x-litellm-response-cost" in headers
        assert float(headers["x-litellm-response-cost"]) == 0.0001

        # Discount headers should not be in the final dict
        assert "x-litellm-response-cost-original" not in headers
        assert "x-litellm-response-cost-discount-amount" not in headers

    def test_get_custom_headers_with_margin_info(self):
        """
        Test that margin headers are included when margin is applied.
        """
        from litellm.litellm_core_utils.litellm_logging import (
            Logging as LiteLLMLoggingObj,
        )

        # Create mock user API key dict
        mock_user_api_key_dict = MagicMock(spec=UserAPIKeyAuth)
        mock_user_api_key_dict.tpm_limit = None
        mock_user_api_key_dict.rpm_limit = None
        mock_user_api_key_dict.max_budget = None
        mock_user_api_key_dict.spend = 0

        # Create logging object with margin
        logging_obj = LiteLLMLoggingObj(
            model="gpt-4",
            messages=[],
            stream=False,
            call_type="completion",
            start_time=None,
            litellm_call_id="test-call-id-margin",
            function_id="test-function",
        )
        logging_obj.set_cost_breakdown(
            input_cost=0.00005,
            output_cost=0.00005,
            total_cost=0.00011,
            cost_for_built_in_tools_cost_usd_dollar=0.0,
            original_cost=0.0001,
            margin_percent=0.10,
            margin_total_amount=0.00001,
        )

        headers = ProxyBaseLLMRequestProcessing.get_custom_headers(
            user_api_key_dict=mock_user_api_key_dict,
            response_cost=0.00011,
            litellm_logging_obj=logging_obj,
        )

        # Verify margin headers are present
        assert "x-litellm-response-cost" in headers
        assert float(headers["x-litellm-response-cost"]) == 0.00011

        assert "x-litellm-response-cost-margin-amount" in headers
        assert float(headers["x-litellm-response-cost-margin-amount"]) == 0.00001

        assert "x-litellm-response-cost-margin-percent" in headers
        assert float(headers["x-litellm-response-cost-margin-percent"]) == 0.10

    def test_get_custom_headers_without_margin_info(self):
        """
        Test that when no margin is applied, margin headers are not included.
        """
        from litellm.litellm_core_utils.litellm_logging import (
            Logging as LiteLLMLoggingObj,
        )

        # Create mock user API key dict
        mock_user_api_key_dict = MagicMock(spec=UserAPIKeyAuth)
        mock_user_api_key_dict.tpm_limit = None
        mock_user_api_key_dict.rpm_limit = None
        mock_user_api_key_dict.max_budget = None
        mock_user_api_key_dict.spend = 0

        # Create logging object without margin
        logging_obj = LiteLLMLoggingObj(
            model="gpt-4",
            messages=[],
            stream=False,
            call_type="completion",
            start_time=None,
            litellm_call_id="test-call-id-no-margin",
            function_id="test-function",
        )
        logging_obj.set_cost_breakdown(
            input_cost=0.00005,
            output_cost=0.00005,
            total_cost=0.0001,
            cost_for_built_in_tools_cost_usd_dollar=0.0,
        )

        headers = ProxyBaseLLMRequestProcessing.get_custom_headers(
            user_api_key_dict=mock_user_api_key_dict,
            response_cost=0.0001,
            litellm_logging_obj=logging_obj,
        )

        # Verify margin headers are not present
        assert "x-litellm-response-cost-margin-amount" not in headers
        assert "x-litellm-response-cost-margin-percent" not in headers

    def test_get_cost_breakdown_from_logging_obj_helper(self):
        """
        Test the helper function that extracts cost breakdown information.
        """
        from litellm.litellm_core_utils.litellm_logging import (
            Logging as LiteLLMLoggingObj,
        )

        # Test with discount info
        logging_obj = LiteLLMLoggingObj(
            model="vertex_ai/gemini-pro",
            messages=[{"role": "user", "content": "test"}],
            stream=False,
            call_type="completion",
            start_time=None,
            litellm_call_id="test-call-id",
            function_id="test-function-id",
        )
        logging_obj.set_cost_breakdown(
            input_cost=0.00005,
            output_cost=0.00005,
            total_cost=0.000095,
            cost_for_built_in_tools_cost_usd_dollar=0.0,
            original_cost=0.0001,
            discount_percent=0.05,
            discount_amount=0.000005,
        )

        original_cost, discount_amount, margin_total_amount, margin_percent = _get_cost_breakdown_from_logging_obj(logging_obj)
        assert original_cost == 0.0001
        assert discount_amount == 0.000005
        assert margin_total_amount is None
        assert margin_percent is None

        # Test with margin info
        logging_obj_with_margin = LiteLLMLoggingObj(
            model="gpt-4",
            messages=[{"role": "user", "content": "test"}],
            stream=False,
            call_type="completion",
            start_time=None,
            litellm_call_id="test-call-id-margin",
            function_id="test-function-id-margin",
        )
        logging_obj_with_margin.set_cost_breakdown(
            input_cost=0.00005,
            output_cost=0.00005,
            total_cost=0.00011,
            cost_for_built_in_tools_cost_usd_dollar=0.0,
            original_cost=0.0001,
            margin_percent=0.10,
            margin_total_amount=0.00001,
        )

        original_cost, discount_amount, margin_total_amount, margin_percent = _get_cost_breakdown_from_logging_obj(logging_obj_with_margin)
        assert original_cost == 0.0001
        assert discount_amount is None
        assert margin_total_amount == 0.00001
        assert margin_percent == 0.10

        # Test with no discount or margin info
        logging_obj_no_discount = LiteLLMLoggingObj(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "test"}],
            stream=False,
            call_type="completion",
            start_time=None,
            litellm_call_id="test-call-id-2",
            function_id="test-function-id-2",
        )
        logging_obj_no_discount.set_cost_breakdown(
            input_cost=0.00005,
            output_cost=0.00005,
            total_cost=0.0001,
            cost_for_built_in_tools_cost_usd_dollar=0.0,
        )

        original_cost, discount_amount, margin_total_amount, margin_percent = _get_cost_breakdown_from_logging_obj(logging_obj_no_discount)
        assert original_cost is None
        assert discount_amount is None
        assert margin_total_amount is None
        assert margin_percent is None

        # Test with None logging object
        original_cost, discount_amount, margin_total_amount, margin_percent = _get_cost_breakdown_from_logging_obj(None)
        assert original_cost is None
        assert discount_amount is None
        assert margin_total_amount is None
        assert margin_percent is None

    def test_get_custom_headers_key_spend_includes_response_cost(self):
        """
        Test that x-litellm-key-spend header includes the current request's response_cost.

        This ensures that the spend header reflects the updated spend including the current
        request, even though spend tracking updates happen asynchronously after the response.
        """
        # Create mock user API key dict with initial spend
        mock_user_api_key_dict = MagicMock(spec=UserAPIKeyAuth)
        mock_user_api_key_dict.tpm_limit = None
        mock_user_api_key_dict.rpm_limit = None
        mock_user_api_key_dict.max_budget = None
        mock_user_api_key_dict.spend = 0.001  # Initial spend: $0.001

        # Test case 1: response_cost is provided as float
        response_cost_1 = 0.0005  # Current request cost: $0.0005
        headers_1 = ProxyBaseLLMRequestProcessing.get_custom_headers(
            user_api_key_dict=mock_user_api_key_dict,
            call_id="test-call-id-1",
            response_cost=response_cost_1,
        )

        assert "x-litellm-key-spend" in headers_1
        expected_spend_1 = 0.001 + 0.0005  # Initial spend + current request cost
        assert float(headers_1["x-litellm-key-spend"]) == pytest.approx(expected_spend_1, abs=1e-10)
        assert float(headers_1["x-litellm-response-cost"]) == response_cost_1

        # Test case 2: response_cost is provided as string
        response_cost_2 = "0.0003"  # Current request cost as string
        headers_2 = ProxyBaseLLMRequestProcessing.get_custom_headers(
            user_api_key_dict=mock_user_api_key_dict,
            call_id="test-call-id-2",
            response_cost=response_cost_2,
        )

        assert "x-litellm-key-spend" in headers_2
        expected_spend_2 = 0.001 + 0.0003  # Initial spend + current request cost
        assert float(headers_2["x-litellm-key-spend"]) == pytest.approx(expected_spend_2, abs=1e-10)

        # Test case 3: response_cost is None (should use original spend)
        headers_3 = ProxyBaseLLMRequestProcessing.get_custom_headers(
            user_api_key_dict=mock_user_api_key_dict,
            call_id="test-call-id-3",
            response_cost=None,
        )

        assert "x-litellm-key-spend" in headers_3
        assert float(headers_3["x-litellm-key-spend"]) == 0.001  # Should use original spend

        # Test case 4: response_cost is 0 (should not change spend)
        headers_4 = ProxyBaseLLMRequestProcessing.get_custom_headers(
            user_api_key_dict=mock_user_api_key_dict,
            call_id="test-call-id-4",
            response_cost=0.0,
        )

        assert "x-litellm-key-spend" in headers_4
        assert float(headers_4["x-litellm-key-spend"]) == 0.001  # Should remain unchanged for 0 cost

        # Test case 5: user_api_key_dict.spend is None (should default to 0.0)
        mock_user_api_key_dict.spend = None
        headers_5 = ProxyBaseLLMRequestProcessing.get_custom_headers(
            user_api_key_dict=mock_user_api_key_dict,
            call_id="test-call-id-5",
            response_cost=0.0002,
        )

        assert "x-litellm-key-spend" in headers_5
        assert float(headers_5["x-litellm-key-spend"]) == 0.0002  # 0.0 + 0.0002

        # Test case 6: response_cost is negative (should not be added, use original spend)
        mock_user_api_key_dict.spend = 0.001
        headers_6 = ProxyBaseLLMRequestProcessing.get_custom_headers(
            user_api_key_dict=mock_user_api_key_dict,
            call_id="test-call-id-6",
            response_cost=-0.0001,  # Negative cost (should not be added)
        )

        assert "x-litellm-key-spend" in headers_6
        assert float(headers_6["x-litellm-key-spend"]) == 0.001  # Should use original spend

        # Test case 7: response_cost is invalid string (should fallback to original spend)
        headers_7 = ProxyBaseLLMRequestProcessing.get_custom_headers(
            user_api_key_dict=mock_user_api_key_dict,
            call_id="test-call-id-7",
            response_cost="invalid",  # Invalid string
        )

        assert "x-litellm-key-spend" in headers_7
        assert float(headers_7["x-litellm-key-spend"]) == 0.001  # Should use original spend on error


@pytest.mark.asyncio
class TestCommonRequestProcessingHelpers:
    """Tests for the module-level helpers _parse_event_data_for_error and
    create_response (streaming vs. JSON error responses, tracing)."""

    async def consume_stream(self, streaming_response: StreamingResponse) -> list:
        """Drain the response's body_iterator and return the chunks as a list."""
        content = []
        async for chunk_bytes in streaming_response.body_iterator:
            content.append(chunk_bytes)
        return content

    @pytest.mark.parametrize(
        "event_line, expected_code",
        [
            (
                'data: {"error": {"code": 400, "message": "bad request"}}',
                400,
            ),  # Valid integer code
            (
                'data: {"error": {"code": "401", "message": "unauthorized"}}',
                401,
            ),  # Valid string-integer code
            (
                'data: {"error": {"code": "invalid_code", "message": "error"}}',
                None,
            ),  # Invalid string code
            (
                'data: {"error": {"code": 99, "message": "too low"}}',
                None,
            ),  # Integer code too low
            (
                'data: {"error": {"code": 600, "message": "too high"}}',
                None,
            ),  # Integer code too high
            (
                'data: {"id": "123", "content": "hello"}',
                None,
            ),  # Non-error SSE event
            ("data: [DONE]", None),  # SSE [DONE] event
            ("data: ", None),  # SSE empty data event
            (
                'data: {"error": {"code": 400',
                None,
            ),  # Malformed JSON
            ("id: 123", None),  # Non-SSE event line
            (
                'data: {"error": {"message": "some error"}}',
                None,
            ),  # Error event without 'code' field
            (
                'data: {"error": {"code": null, "message": "code is null"}}',
                None,
            ),  # Error with null code
        ],
    )
    async def test_parse_event_data_for_error(self, event_line, expected_code):
        """Each SSE line must parse to the expected status code (or None)."""
        assert await _parse_event_data_for_error(event_line) == expected_code

    async def test_create_streaming_response_first_chunk_is_error(self):
        """
        Test that when the first chunk is an error, a JSON error response is returned
        instead of an SSE streaming response
        """

        async def mock_generator():
            yield 'data: {"error": {"code": 403, "message": "forbidden"}}\n\n'
            yield 'data: {"content": "more data"}\n\n'
            yield "data: [DONE]\n\n"

        response = await create_response(
            mock_generator(), "text/event-stream", {}
        )
        # Should return JSONResponse instead of StreamingResponse
        assert isinstance(response, JSONResponse)
        assert response.status_code == status.HTTP_403_FORBIDDEN
        # Verify the response is in standard JSON error format
        import json

        body = json.loads(response.body.decode())
        assert "error" in body
        assert body["error"]["code"] == 403
        assert body["error"]["message"] == "forbidden"

    async def test_create_streaming_response_first_chunk_not_error(self):
        """A non-error first chunk yields a 200 stream with every chunk passed through."""

        async def mock_generator():
            yield 'data: {"content": "first part"}\n\n'
            yield 'data: {"content": "second part"}\n\n'
            yield "data: [DONE]\n\n"

        response = await create_response(
            mock_generator(), "text/event-stream", {}
        )
        assert response.status_code == status.HTTP_200_OK
        content = await self.consume_stream(response)
        assert content == [
            'data: {"content": "first part"}\n\n',
            'data: {"content": "second part"}\n\n',
            "data: [DONE]\n\n",
        ]

    async def test_create_streaming_response_empty_generator(self):
        """A generator that never yields produces a 200 response with no chunks."""

        async def mock_generator():
            if False:  # Never yields
                yield
            # Implicitly raises StopAsyncIteration

        response = await create_response(
            mock_generator(), "text/event-stream", {}
        )
        assert response.status_code == status.HTTP_200_OK
        content = await self.consume_stream(response)
        assert content == []

    async def test_create_streaming_response_generator_raises_stop_async_iteration_immediately(
        self,
    ):
        """StopAsyncIteration on the first __anext__ yields a 200 response with no chunks."""
        mock_gen = AsyncMock()
        mock_gen.__anext__.side_effect = StopAsyncIteration

        response = await create_response(mock_gen, "text/event-stream", {})
        assert response.status_code == status.HTTP_200_OK
        content = await self.consume_stream(response)
        assert content == []

    async def test_create_streaming_response_generator_raises_unexpected_exception(
        self,
    ):
        """An unexpected exception on the first __anext__ yields a 500 stream
        containing one error event followed by [DONE]."""
        mock_gen = AsyncMock()
        mock_gen.__anext__.side_effect = ValueError("Test error from generator")

        response = await create_response(mock_gen, "text/event-stream", {})
        assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
        content = await self.consume_stream(response)
        expected_error_data = {
            "error": {
                "message": "Error processing stream start",
                "code": status.HTTP_500_INTERNAL_SERVER_ERROR,
            }
        }
        assert len(content) == 2
        # Use json.dumps to match the formatting in create_streaming_response's exception handler
        import json

        assert content[0] == f"data: {json.dumps(expected_error_data)}\n\n"
        assert content[1] == "data: [DONE]\n\n"

    async def test_create_streaming_response_first_chunk_error_string_code(self):
        """
        Test that when the first chunk contains a string error code, a JSON error response is returned
        """

        async def mock_generator():
            yield 'data: {"error": {"code": "429", "message": "too many requests"}}\n\n'
            yield "data: [DONE]\n\n"

        response = await create_response(
            mock_generator(), "text/event-stream", {}
        )
        assert isinstance(response, JSONResponse)
        assert response.status_code == status.HTTP_429_TOO_MANY_REQUESTS
        # Verify the response is in standard JSON error format
        import json

        body = json.loads(response.body.decode())
        assert "error" in body
        assert body["error"]["code"] == "429"
        assert body["error"]["message"] == "too many requests"

    async def test_create_streaming_response_custom_headers(self):
        """Custom headers passed to create_response appear on the response."""

        async def mock_generator():
            yield 'data: {"content": "data"}\n\n'
            yield "data: [DONE]\n\n"

        custom_headers = {"X-Custom-Header": "TestValue"}
        response = await create_response(
            mock_generator(), "text/event-stream", custom_headers
        )
        assert response.headers["x-custom-header"] == "TestValue"

    async def test_create_streaming_response_non_default_status_code(self):
        """default_status_code is used as the status of a non-error stream."""

        async def mock_generator():
            yield 'data: {"content": "data"}\n\n'
            yield "data: [DONE]\n\n"

        response = await create_response(
            mock_generator(),
            "text/event-stream",
            {},
            default_status_code=status.HTTP_201_CREATED,
        )
        assert response.status_code == status.HTTP_201_CREATED
        content = await self.consume_stream(response)
        assert content == [
            'data: {"content": "data"}\n\n',
            "data: [DONE]\n\n",
        ]

    async def test_create_streaming_response_first_chunk_is_done(self):
        """A stream whose only chunk is [DONE] is passed through with status 200."""

        async def mock_generator():
            yield "data: [DONE]\n\n"

        response = await create_response(
            mock_generator(), "text/event-stream", {}
        )
        assert response.status_code == status.HTTP_200_OK  # Default status
        content = await self.consume_stream(response)
        assert content == ["data: [DONE]\n\n"]

    async def test_create_streaming_response_first_chunk_is_empty_data(self):
        """An empty-data first chunk is not treated as an error; all chunks pass through."""

        async def mock_generator():
            yield "data: \n\n"
            yield 'data: {"content": "actual data"}\n\n'
            yield "data: [DONE]\n\n"

        response = await create_response(
            mock_generator(), "text/event-stream", {}
        )
        assert response.status_code == status.HTTP_200_OK  # Default status
        content = await self.consume_stream(response)
        assert content == [
            "data: \n\n",
            'data: {"content": "actual data"}\n\n',
            "data: [DONE]\n\n",
        ]

    async def test_create_streaming_response_all_chunks_have_dd_trace(self):
        """Test that all stream chunks are wrapped with dd trace at the streaming generator level"""
        from unittest.mock import patch

        # Create a mock tracer
        mock_tracer = MagicMock()
        mock_span = MagicMock()
        mock_tracer.trace.return_value.__enter__.return_value = mock_span
        mock_tracer.trace.return_value.__exit__.return_value = None

        # Mock generator with multiple chunks
        async def mock_generator():
            yield 'data: {"content": "chunk 1"}\n\n'
            yield 'data: {"content": "chunk 2"}\n\n'
            yield 'data: {"content": "chunk 3"}\n\n'
            yield "data: [DONE]\n\n"

        # Patch the tracer in the common_request_processing module
        with patch("litellm.proxy.common_request_processing.tracer", mock_tracer):
            response = await create_response(
                mock_generator(), "text/event-stream", {}
            )

            assert response.status_code == 200

            # Consume the stream to trigger the tracer calls
            content = await self.consume_stream(response)

            # Verify all chunks are present
            assert len(content) == 4
            assert content[0] == 'data: {"content": "chunk 1"}\n\n'
            assert content[1] == 'data: {"content": "chunk 2"}\n\n'
            assert content[2] == 'data: {"content": "chunk 3"}\n\n'
            assert content[3] == "data: [DONE]\n\n"

            # Verify that tracer.trace was called for each chunk (4 chunks total)
            assert mock_tracer.trace.call_count == 4

            # Verify that each call was made with the correct operation name
            actual_calls = mock_tracer.trace.call_args_list
            assert len(actual_calls) == 4

            for i, call in enumerate(actual_calls):
                args, kwargs = call
                assert (
                    args[0] == "streaming.chunk.yield"
                ), f"Call {i} should have operation name 'streaming.chunk.yield', got {args[0]}"

    async def test_create_streaming_response_dd_trace_with_error_chunk(self):
        """
        Test that when the first chunk contains an error, JSONResponse is returned
        and tracing is not triggered (since it's not a streaming response)
        """
        from unittest.mock import patch

        # Create a mock tracer
        mock_tracer = MagicMock()
        mock_span = MagicMock()
        mock_tracer.trace.return_value.__enter__.return_value = mock_span
        mock_tracer.trace.return_value.__exit__.return_value = None

        # Mock generator with error in first chunk
        async def mock_generator():
            yield 'data: {"error": {"code": 400, "message": "bad request"}}\n\n'
            yield 'data: {"content": "chunk after error"}\n\n'
            yield "data: [DONE]\n\n"

        # Patch the tracer in the common_request_processing module
        with patch("litellm.proxy.common_request_processing.tracer", mock_tracer):
            response = await create_response(
                mock_generator(), "text/event-stream", {}
            )

            # Should return JSONResponse instead of StreamingResponse
            assert isinstance(response, JSONResponse)
            assert response.status_code == 400

            # Verify the response is in standard JSON error format
            import json

            body = json.loads(response.body.decode())
            assert "error" in body
            assert body["error"]["code"] == 400
            assert body["error"]["message"] == "bad request"

            # Since JSONResponse is returned instead of StreamingResponse, streaming tracing should not be triggered
            # tracer.trace should not be called
            assert mock_tracer.trace.call_count == 0


class TestExtractErrorFromSSEChunk:
    """Tests for _extract_error_from_sse_chunk function"""

    def test_extract_error_from_sse_chunk_with_valid_error(self):
        """Test extracting error information from a standard SSE chunk"""
        chunk = 'data: {"error": {"code": 403, "message": "forbidden", "type": "auth_error", "param": "api_key"}}\n\n'
        error = _extract_error_from_sse_chunk(chunk)

        assert error["code"] == 403
        assert error["message"] == "forbidden"
        assert error["type"] == "auth_error"
        assert error["param"] == "api_key"

    def test_extract_error_from_sse_chunk_with_string_code(self):
        """Test error code as string type"""
        chunk = 'data: {"error": {"code": "429", "message": "too many requests"}}\n\n'
        error = _extract_error_from_sse_chunk(chunk)

        assert error["code"] == "429"
        assert error["message"] == "too many requests"

    def test_extract_error_from_sse_chunk_with_bytes(self):
        """Test input as bytes type"""
        chunk = b'data: {"error": {"code": 500, "message": "internal error"}}\n\n'
        error = _extract_error_from_sse_chunk(chunk)

        assert error["code"] == 500
        assert error["message"] == "internal error"

    def test_extract_error_from_sse_chunk_with_done(self):
        """Test [DONE] marker should return default error"""
        chunk = "data: [DONE]\n\n"
        error = _extract_error_from_sse_chunk(chunk)

        assert error["message"] == "Unknown error"
        assert error["type"] == "internal_server_error"
        assert error["code"] == "500"
        assert error["param"] is None

    def test_extract_error_from_sse_chunk_without_error_field(self):
        """Test missing error field should return default error"""
        chunk = 'data: {"content": "some content"}\n\n'
        error = _extract_error_from_sse_chunk(chunk)

        assert error["message"] == "Unknown error"
        assert error["type"] == "internal_server_error"
        assert error["code"] == "500"

    def test_extract_error_from_sse_chunk_with_invalid_json(self):
        """Test invalid JSON should return default error"""
        chunk = 'data: {invalid json}\n\n'
        error = _extract_error_from_sse_chunk(chunk)

        assert error["message"] == "Unknown error"
        assert error["type"] == "internal_server_error"
        assert error["code"] == "500"

    def test_extract_error_from_sse_chunk_without_data_prefix(self):
        """Test missing 'data:' prefix should return default error"""
        chunk = '{"error": {"code": 400, "message": "bad request"}}\n\n'
        error = _extract_error_from_sse_chunk(chunk)

        assert error["message"] == "Unknown error"
        assert error["type"] == "internal_server_error"
        assert error["code"] == "500"

    def test_extract_error_from_sse_chunk_with_empty_string(self):
        """Test empty string should return default error"""
        chunk = ""
        error = _extract_error_from_sse_chunk(chunk)

        assert error["message"] == "Unknown error"
        assert error["type"] == "internal_server_error"
        assert error["code"] == "500"

    def test_extract_error_from_sse_chunk_with_minimal_error(self):
        """Test minimal error object"""
        chunk = 'data: {"error": {"message": "error occurred"}}\n\n'
        error = _extract_error_from_sse_chunk(chunk)

        assert error["message"] == "error occurred"
        # Other fields should be obtained from the original error object (if exists)