mirror of
https://github.com/tiennm99/litellm.git
synced 2026-06-17 14:48:44 +00:00
2453936a82
* Add support for websocket via codex * Add model alias and creds support * fix: skip cost tracking for WS session wrapper call types The @client decorator on _aresponses_websocket fires async_success_handler with result=None after the session ends. This triggered cost tracking errors because standard_logging_object is never built for None results. Per-turn costs are correctly tracked by individual litellm.aresponses calls inside the session. The outer session-level logging obj should not attempt cost tracking. Fix: skip _aresponses_websocket and _arealtime call types in deployment_callback_on_success, RouterBudgetLimiting.async_log_success_event, and _PROXY_track_cost_callback. * fix: address Greptile review comments Fix JSON injection: use json.dumps instead of f-string interpolation for model name in WS body. Add 30s timeout for first WS frame to prevent unbounded connection resource tie-up. Restore per-event model override in streaming_iterator; fall back to connection-level model when event omits it. Strengthen regression test: inject alias into kwargs via _update_kwargs_with_deployment mock so the test would fail on un-fixed code. * fix: handle nested response.create format in first-frame model extraction When ?model= is omitted, the first WS frame can carry the model in either flat format (first_event["model"]) or nested format (first_event["response"]["model"]). The flat-only check would silently reject clients using the nested wire format. Mirrors the same two-format logic in _build_base_call_kwargs. * fix: don't force connection-level custom_llm_provider on per-event model overrides If a client sends a different model per response.create turn, litellm needs to re-resolve the provider from that model string. Forcing the connection-level custom_llm_provider would silently route the request to the wrong backend. Only inject custom_llm_provider when the per-event model matches the connection-level model. 
* refactor: extract WS model extraction into testable function Pull the flat/nested model extraction into _extract_model_from_first_ws_event so tests import and exercise the real function rather than a copy. * fix: compare providers not full model strings in _inject_credentials The model == self.model guard was too strict: same-provider model variants (e.g., vertex_ai/gemini-2.0 -> vertex_ai/gemini-1.5 on one connection) would lose custom_llm_provider, breaking routing when a custom api_base is in use. Compare the provider extracted by get_llm_provider instead, so same-provider variants still inherit the connection-level provider while cross-provider overrides let litellm re-resolve. * style: black formatting * refactor: extract first-frame model resolution to fix PLR0915 (too many statements) * Fix responses WebSocket first-frame validation * fix: classify WS first-frame read errors and clarify cost-skip log Distinguish client disconnects from server errors when reading the responses WebSocket first frame, make the cost-tracking skip log message accurate for session wrappers (which do carry a model), and resolve the connection-level provider once per session instead of on every response.create event. * test: cover WS first-frame read errors and same-provider credential injection Adds regression tests for the still-uncovered responses WebSocket paths: the timeout, invalid-JSON and missing-model branches of _read_ws_model_from_first_frame, plus the provider comparison in ManagedResponsesWebSocketHandler._same_provider and _inject_credentials (same-provider model variants keep the connection provider; cross-provider models re-resolve). 
* fix(responses-ws): fall back to explicit custom_llm_provider when connection model is unresolvable When a WebSocket session is opened with a custom deployment alias that litellm cannot resolve to a provider, _connection_provider was None, so _same_provider returned False for every resolvable per-event model and the connection-level custom_llm_provider was dropped. Use the explicitly-set custom_llm_provider as the connection provider in that case so same-provider per-event models still inherit it while genuinely cross-provider models continue to re-resolve. --------- Co-authored-by: Cursor Agent <cursoragent@cursor.com> Co-authored-by: mateo-berri <277851410+mateo-berri@users.noreply.github.com>
714 lines
25 KiB
Python
714 lines
25 KiB
Python
"""
|
|
Test for response_api_endpoints/endpoints.py
|
|
"""
|
|
|
|
import unittest
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
from fastapi.testclient import TestClient
|
|
|
|
from litellm.proxy.proxy_server import app
|
|
|
|
|
|
class TestResponsesAPIEndpoints(unittest.TestCase):
    """HTTP-level tests for the Responses API routes served by the proxy app.

    The test methods are intentionally synchronous. ``unittest.TestCase``
    calls a test method and ignores what it returns, so an ``async def`` test
    only creates a coroutine that is never awaited and the test "passes"
    without executing any assertion (pytest-asyncio does not support
    ``unittest.TestCase`` subclasses). Every request below goes through the
    synchronous ``TestClient`` API, so no event loop is needed in the tests.
    """

    @patch("litellm.proxy.proxy_server.llm_router")
    @patch("litellm.proxy.proxy_server.user_api_key_auth")
    def test_openai_v1_responses_route(self, mock_auth, mock_router):
        """
        Test that /openai/v1/responses endpoint is correctly registered and accessible.
        """
        mock_auth.return_value = MagicMock(
            token="test_token",
            user_id="test_user",
            team_id=None,
        )

        mock_router.aresponses = AsyncMock(
            return_value={
                "id": "resp_abc123",
                "object": "realtime.response",
                "status": "completed",
                "output": [
                    {
                        "type": "message",
                        "role": "assistant",
                        "content": [{"type": "text", "text": "Test response"}],
                    }
                ],
            }
        )

        client = TestClient(app)

        test_data = {"model": "gpt-4o", "input": "Tell me about AI"}

        response = client.post(
            "/openai/v1/responses",
            json=test_data,
            headers={"Authorization": "Bearer sk-1234"},
        )

        # Any of these codes proves the route exists (an unknown route is 404).
        assert response.status_code in [200, 401, 500]

    @patch("litellm.proxy.proxy_server.llm_router")
    @patch("litellm.proxy.proxy_server.user_api_key_auth")
    def test_cursor_chat_completions_route(self, mock_auth, mock_router):
        """
        Test that /cursor/chat/completions endpoint:
        1. Accepts Responses API input format
        2. Returns chat completions format response (checked only on HTTP 200)
        """
        from litellm.types.llms.openai import ResponsesAPIResponse
        from litellm.types.utils import ResponseOutputMessage, ResponseOutputText

        mock_auth.return_value = MagicMock(
            token="test_token",
            user_id="test_user",
            team_id=None,
        )

        # Mock a Responses API response
        mock_responses_response = ResponsesAPIResponse(
            id="resp_cursor123",
            created_at=1234567890,
            model="gpt-4o",
            object="response",
            output=[
                ResponseOutputMessage(
                    type="message",
                    role="assistant",
                    content=[
                        ResponseOutputText(
                            type="output_text", text="Hello from Cursor!"
                        )
                    ],
                )
            ],
        )

        mock_router.aresponses = AsyncMock(return_value=mock_responses_response)

        client = TestClient(app)

        # Test with Responses API input format (what Cursor sends)
        test_data = {
            "model": "gpt-4o",
            "input": [{"role": "user", "content": "Hello"}],
        }

        response = client.post(
            "/cursor/chat/completions",
            json=test_data,
            headers={"Authorization": "Bearer sk-1234"},
        )

        # Should return 200 (or 401/500 if auth fails)
        assert response.status_code in [200, 401, 500]

        # If successful, verify it returns chat completions format
        if response.status_code == 200:
            response_data = response.json()
            # Should have chat completion structure
            assert "choices" in response_data or "id" in response_data
            # Should not have Responses API structure
            assert "output" not in response_data or "status" not in response_data

    @patch("litellm.proxy.proxy_server.llm_router")
    @patch("litellm.proxy.proxy_server.user_api_key_auth")
    def test_responses_api_key_spend_header_includes_response_cost(
        self, mock_auth, mock_router
    ):
        """
        Test that x-litellm-key-spend header includes the current request's response_cost
        for /v1/responses endpoint.

        This ensures the spend header reflects updated spend including the current request,
        even though spend tracking updates happen asynchronously after the response.
        """
        from litellm.types.llms.openai import ResponsesAPIResponse
        from litellm.types.utils import ResponseOutputMessage, ResponseOutputText

        # Create mock user API key with initial spend
        mock_user_api_key_dict = MagicMock()
        mock_user_api_key_dict.token = "test_token"
        mock_user_api_key_dict.user_id = "test_user"
        mock_user_api_key_dict.team_id = None
        mock_user_api_key_dict.spend = 0.001  # Initial spend: $0.001
        mock_user_api_key_dict.tpm_limit = None
        mock_user_api_key_dict.rpm_limit = None
        mock_user_api_key_dict.max_budget = None
        mock_user_api_key_dict.allowed_model_region = None
        mock_user_api_key_dict.api_key = "sk-test-key"
        mock_user_api_key_dict.metadata = {}

        mock_auth.return_value = mock_user_api_key_dict

        # Mock response with hidden_params containing response_cost
        mock_response = ResponsesAPIResponse(
            id="resp_test123",
            created_at=1234567890,
            model="gpt-4o",
            object="response",
            output=[
                ResponseOutputMessage(
                    type="message",
                    role="assistant",
                    content=[
                        ResponseOutputText(type="output_text", text="Test response")
                    ],
                )
            ],
        )

        # Add hidden_params with response_cost to the mock response
        mock_response._hidden_params = {
            "response_cost": 0.0005,  # Current request cost: $0.0005
            "model_id": "test-model-id",
        }

        mock_router.aresponses = AsyncMock(return_value=mock_response)

        client = TestClient(app)

        test_data = {"model": "gpt-4o", "input": "Tell me about AI"}

        response = client.post(
            "/v1/responses",
            json=test_data,
            headers={"Authorization": "Bearer sk-test-key"},
        )

        # Verify the response was successful
        assert response.status_code == 200

        # Verify x-litellm-key-spend header includes current request cost
        assert "x-litellm-key-spend" in response.headers
        key_spend_value = float(response.headers["x-litellm-key-spend"])
        expected_spend = 0.001 + 0.0005  # Initial spend + current request cost
        assert key_spend_value == pytest.approx(expected_spend, abs=1e-10)

        # Verify x-litellm-response-cost header is present
        assert "x-litellm-response-cost" in response.headers
        response_cost_value = float(response.headers["x-litellm-response-cost"])
        assert response_cost_value == pytest.approx(0.0005, abs=1e-10)
|
|
|
|
|
|
import json
|
|
|
|
|
|
class TestManagedResponsesWSFirstMessage:
    """How ``ManagedResponsesWebSocketHandler.run`` treats ``first_message``."""

    @staticmethod
    async def _run_and_record(incoming, first_message):
        """Run a handler over a mocked websocket; return what it processed.

        ``incoming`` becomes the ``side_effect`` of the socket's
        ``receive_text``. The handler's ``_process_response_create`` is
        swapped for a recorder so the test only observes which raw messages
        reach it, and in what order.
        """
        from litellm.responses.streaming_iterator import (
            ManagedResponsesWebSocketHandler,
        )

        socket = MagicMock()
        socket.receive_text = AsyncMock(side_effect=incoming)
        socket.send_text = AsyncMock()

        seen: list = []

        async def record(msg: str) -> None:
            seen.append(msg)

        handler = ManagedResponsesWebSocketHandler(
            websocket=socket,
            model="gpt-4o-mini",
            logging_obj=MagicMock(),
            first_message=first_message,
        )
        handler._process_response_create = record  # type: ignore[method-assign]

        await handler.run()
        return seen

    @pytest.mark.asyncio
    async def test_first_message_processed_before_loop(self):
        """
        The handler must process first_message before it starts reading from
        the socket. Regression for clients that connect without ?model=
        (e.g. Codex) and put the model in their first response.create event.
        """
        user_turn = {
            "type": "message",
            "role": "user",
            "content": [{"type": "input_text", "text": "hi"}],
        }
        first = json.dumps(
            {
                "type": "response.create",
                "model": "gpt-4o-mini",
                "store": False,
                "input": [user_turn],
            }
        )

        # The very first receive_text() raises, so only first_message can be seen.
        processed = await self._run_and_record(
            incoming=Exception("disconnect"), first_message=first
        )

        assert processed == [first]

    @pytest.mark.asyncio
    async def test_no_first_message_falls_through_to_loop(self):
        """With first_message=None, run() goes straight to receive_text()."""
        subsequent = json.dumps({"type": "response.create", "model": "gpt-4o-mini"})

        processed = await self._run_and_record(
            incoming=[subsequent, Exception("disconnect")], first_message=None
        )

        assert processed == [subsequent]
|
|
|
|
|
|
class TestResponsesWSStreamingFirstMessage:
    """Replay of the buffered first frame by ``ResponsesWebSocketStreaming``."""

    @pytest.mark.asyncio
    async def test_client_to_backend_replays_first_message(self):
        """
        client_to_backend must forward first_message to the backend socket
        before it starts reading further frames from the client.
        """
        from litellm.responses.streaming_iterator import ResponsesWebSocketStreaming

        frame = {"type": "response.create", "model": "gpt-4o-mini", "input": []}
        first_frame = json.dumps(frame)

        upstream = MagicMock()
        upstream.send = AsyncMock()

        # The client socket fails on its first read, so the only thing that
        # can reach the backend is the replayed first frame.
        client_socket = MagicMock()
        client_socket.receive_text = AsyncMock(side_effect=Exception("disconnect"))

        relay = ResponsesWebSocketStreaming(
            websocket=client_socket,
            backend_ws=upstream,
            logging_obj=MagicMock(),
            first_message=first_frame,
        )
        await relay.client_to_backend()

        upstream.send.assert_awaited_once_with(first_frame)
|
|
|
|
|
|
class TestWSSessionCostTracking:
    """Session-wrapper call types must not break router budget logging."""

    @staticmethod
    async def _log_session_success(call_type, provider):
        """Fire ``async_log_success_event`` the way a finished WS session does.

        The event carries ``standard_logging_object=None`` and a ``None``
        response. The test passes as long as the call does not raise.
        """
        from litellm.router_strategy.budget_limiter import RouterBudgetLimiting

        # __new__ gives a bare instance without running __init__.
        limiter = RouterBudgetLimiting.__new__(RouterBudgetLimiting)
        event_kwargs = {
            "call_type": call_type,
            "standard_logging_object": None,
            "litellm_params": {"custom_llm_provider": provider},
        }
        await limiter.async_log_success_event(
            kwargs=event_kwargs,
            response_obj=None,
            start_time=None,
            end_time=None,
        )

    @pytest.mark.asyncio
    async def test_router_budget_limiter_skips_aresponses_websocket_call_type(self):
        """
        RouterBudgetLimiting.async_log_success_event must not raise for
        call_type='_aresponses_websocket', even with standard_logging_object
        set to None. Per-turn costs are tracked by the individual aresponses
        calls inside the session; the outer session wrapper fires with
        result=None.
        """
        await self._log_session_success("_aresponses_websocket", "vertex_ai")

    @pytest.mark.asyncio
    async def test_router_budget_limiter_skips_arealtime_call_type(self):
        """Same guard applies to _arealtime WS session wrappers."""
        await self._log_session_success("_arealtime", "openai")
|
|
|
|
|
|
class TestWSModelExtraction:
    """Test _extract_model_from_first_ws_event for flat and nested frame formats."""

    @staticmethod
    def _extract(event):
        """Import the function under test lazily and apply it to ``event``."""
        from litellm.proxy.response_api_endpoints.endpoints import (
            _extract_model_from_first_ws_event,
        )

        return _extract_model_from_first_ws_event(event)

    def test_flat_format_extracts_model(self):
        """The model is read from the top level of the event."""
        flat = {"type": "response.create", "model": "gpt-4o", "input": "hello"}
        assert self._extract(flat) == "gpt-4o"

    def test_nested_format_extracts_model(self):
        """The model is read from inside the ``response`` object."""
        nested = {
            "type": "response.create",
            "response": {"model": "gpt-4o", "input": "hello"},
        }
        assert self._extract(nested) == "gpt-4o"

    def test_nested_format_takes_precedence_over_flat(self):
        """When both locations carry a model, the nested one wins."""
        both = {
            "type": "response.create",
            "model": "flat-model",
            "response": {"model": "nested-model"},
        }
        assert self._extract(both) == "nested-model"

    def test_no_model_returns_none(self):
        """An event with no model in either location yields None."""
        without_model = {"type": "response.create", "input": "hello"}
        assert self._extract(without_model) is None

    def test_non_object_returns_none(self):
        """A payload that is not a JSON object (here a list) yields None."""
        assert self._extract([]) is None
|
|
|
|
|
|
class TestResponsesWSFirstFrameValidation:
    """Rejection and error paths of ``_read_ws_model_from_first_frame``."""

    @staticmethod
    def _socket(**receive_text_config):
        """Build a websocket double.

        ``receive_text_config`` is forwarded to the ``AsyncMock`` backing
        ``receive_text`` (``return_value=...`` or ``side_effect=...``).
        """
        ws = MagicMock()
        ws.receive_text = AsyncMock(**receive_text_config)
        ws.send_text = AsyncMock()
        ws.close = AsyncMock()
        return ws

    @staticmethod
    async def _read_first_frame(ws):
        """Import the function under test lazily and await it on ``ws``."""
        from litellm.proxy.response_api_endpoints.endpoints import (
            _read_ws_model_from_first_frame,
        )

        return await _read_ws_model_from_first_frame(ws)

    @pytest.mark.asyncio
    async def test_rejects_non_response_create_first_frame(self):
        """A first frame whose type is not response.create is refused."""
        frame = json.dumps({"type": "session.update", "model": "gpt-4o"})
        ws = self._socket(return_value=frame)

        outcome = await self._read_first_frame(ws)

        assert outcome is None
        ws.send_text.assert_awaited_once()
        ws.close.assert_awaited_once_with(code=1008, reason="Invalid first message")
        sent = json.loads(ws.send_text.await_args.args[0])
        expected_message = "First message must be a response.create JSON object."
        assert sent["error"]["message"] == expected_message

    @pytest.mark.asyncio
    async def test_rejects_non_object_json_first_frame(self):
        """Valid JSON that is not an object (a list here) is refused."""
        ws = self._socket(return_value=json.dumps(["gpt-4o"]))

        outcome = await self._read_first_frame(ws)

        assert outcome is None
        ws.send_text.assert_awaited_once()
        ws.close.assert_awaited_once_with(code=1008, reason="Invalid first message")

    @pytest.mark.asyncio
    async def test_client_disconnect_first_frame_does_not_close(self):
        """If the client disconnects, nothing is sent and close() is not called."""
        from fastapi import WebSocketDisconnect

        ws = self._socket(side_effect=WebSocketDisconnect(code=1006))

        outcome = await self._read_first_frame(ws)

        assert outcome is None
        ws.close.assert_not_awaited()
        ws.send_text.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_server_error_first_frame_closes_with_internal_error(self):
        """An unexpected read error closes the socket with code 1011."""
        ws = self._socket(side_effect=RuntimeError("boom"))

        outcome = await self._read_first_frame(ws)

        assert outcome is None
        ws.close.assert_awaited_once_with(code=1011, reason="Internal server error")
|
|
|
|
|
|
class TestResponsesWSFirstFrameModelAuth:
    """Model-access checks for a model that arrives in the first WS frame."""

    @pytest.mark.asyncio
    async def test_endpoint_enforces_auth_after_model_from_first_frame(self):
        """With model=None, the endpoint awaits the first-frame auth hook once."""
        from litellm.proxy.response_api_endpoints.endpoints import (
            responses_websocket_endpoint,
        )

        first_frame = json.dumps(
            {"type": "response.create", "model": "gpt-4o-mini", "input": []}
        )

        socket = MagicMock()
        socket.headers = {}
        socket.query_params = {}
        socket.scope = {"headers": []}
        socket.url = "ws://testserver/v1/responses"
        socket.accept = AsyncMock()
        socket.receive_text = AsyncMock(return_value=first_frame)
        socket.close = AsyncMock()

        request_processor = MagicMock()
        request_processor.common_processing_pre_call_logic = AsyncMock(
            return_value=({"model": "gpt-4o-mini"}, MagicMock())
        )

        async def fake_llm_call():
            return None

        auth_hook_patch = patch(
            "litellm.proxy.response_api_endpoints.endpoints._enforce_responses_ws_first_frame_model_auth",
            new_callable=AsyncMock,
        )
        processor_patch = patch(
            "litellm.proxy.response_api_endpoints.endpoints.ProxyBaseLLMRequestProcessing",
            return_value=request_processor,
        )
        # The mocked route_request returns a coroutine object for the caller.
        route_patch = patch(
            "litellm.proxy.route_llm_request.route_request",
            new_callable=AsyncMock,
            return_value=fake_llm_call(),
        )

        with auth_hook_patch as mock_model_auth, processor_patch, route_patch:
            await responses_websocket_endpoint(
                websocket=socket,
                model=None,
                user_api_key_dict=MagicMock(),
            )

        mock_model_auth.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_reruns_model_auth_for_first_frame_model(self):
        """The auth hook awaits both key/model and common checks for the model."""
        from starlette.requests import Request

        from litellm.proxy.response_api_endpoints.endpoints import (
            _enforce_responses_ws_first_frame_model_auth,
        )

        http_scope = {
            "type": "http",
            "method": "POST",
            "path": "/v1/responses",
            "headers": [],
        }
        request = Request(http_scope)
        key_obj = MagicMock()
        router = MagicMock()
        expected_request_data = {"model": "gpt-4o-mini"}

        with (
            patch(
                "litellm.proxy.auth.user_api_key_auth._enforce_key_and_fallback_model_access",
                new_callable=AsyncMock,
            ) as mock_key_check,
            patch(
                "litellm.proxy.auth.user_api_key_auth._run_centralized_common_checks",
                new_callable=AsyncMock,
            ) as mock_common_checks,
            patch("litellm.proxy.proxy_server.llm_model_list", []),
            patch("litellm.proxy.proxy_server.master_key", "sk-test"),
            patch("litellm.proxy.proxy_server.user_custom_auth", None),
            patch("litellm.proxy.proxy_server.general_settings", {}),
        ):
            await _enforce_responses_ws_first_frame_model_auth(
                request=request,
                model="gpt-4o-mini",
                user_api_key_dict=key_obj,
                llm_router=router,
            )

        mock_key_check.assert_awaited_once_with(
            valid_token=key_obj,
            request_data=expected_request_data,
            route="/v1/responses",
            request=request,
            llm_model_list=[],
            llm_router=router,
        )
        mock_common_checks.assert_awaited_once_with(
            user_api_key_auth_obj=key_obj,
            request=request,
            request_data=expected_request_data,
            route="/v1/responses",
        )
|
|
|
|
|
|
class TestReadWSModelFromFirstFrameErrors:
    """Timeout, bad-JSON, missing-model and success paths of the first-frame reader."""

    @staticmethod
    def _socket(**receive_text_config):
        """Build a websocket double; kwargs configure the ``receive_text`` mock."""
        ws = MagicMock()
        ws.receive_text = AsyncMock(**receive_text_config)
        ws.send_text = AsyncMock()
        ws.close = AsyncMock()
        return ws

    @staticmethod
    async def _read_first_frame(ws):
        """Import ``_read_ws_model_from_first_frame`` lazily and await it."""
        from litellm.proxy.response_api_endpoints.endpoints import (
            _read_ws_model_from_first_frame,
        )

        return await _read_ws_model_from_first_frame(ws)

    @pytest.mark.asyncio
    async def test_timeout_closes_without_error_frame(self):
        """A timeout closes the socket with 1008 and sends no error frame."""
        import asyncio

        ws = self._socket(side_effect=asyncio.TimeoutError())

        outcome = await self._read_first_frame(ws)

        assert outcome is None
        ws.send_text.assert_not_awaited()
        ws.close.assert_awaited_once_with(
            code=1008, reason="Timed out waiting for first message"
        )

    @pytest.mark.asyncio
    async def test_invalid_json_sends_error_and_closes(self):
        """Unparseable text produces an error frame and a 1008 close."""
        ws = self._socket(return_value="this is not json")

        outcome = await self._read_first_frame(ws)

        assert outcome is None
        error_frame = json.loads(ws.send_text.await_args.args[0])
        assert error_frame["error"]["message"] == "First message is not valid JSON."
        ws.close.assert_awaited_once_with(
            code=1008, reason="Invalid JSON in first message"
        )

    @pytest.mark.asyncio
    async def test_missing_model_sends_error_and_closes(self):
        """A response.create frame without a model is refused."""
        frame = json.dumps({"type": "response.create", "input": []})
        ws = self._socket(return_value=frame)

        outcome = await self._read_first_frame(ws)

        assert outcome is None
        error_frame = json.loads(ws.send_text.await_args.args[0])
        assert "No model provided" in error_frame["error"]["message"]
        ws.close.assert_awaited_once_with(code=1008, reason="No model provided")

    @pytest.mark.asyncio
    async def test_valid_first_frame_returns_model_and_raw(self):
        """A well-formed frame yields (model, raw text) and leaves the socket open."""
        raw = json.dumps({"type": "response.create", "model": "gpt-4o", "input": []})
        ws = self._socket(return_value=raw)

        outcome = await self._read_first_frame(ws)

        assert outcome == ("gpt-4o", raw)
        ws.send_text.assert_not_awaited()
        ws.close.assert_not_awaited()
|
|
ws.send_text.assert_not_awaited()
|
|
ws.close.assert_not_awaited()
|
|
|
|
|
|
class TestManagedResponsesSameProvider:
    """Provider comparison and credential injection on the managed WS handler."""

    def _handler(self, model, custom_llm_provider=None):
        """Return a handler for ``model`` with mocked websocket and logging object."""
        from litellm.responses.streaming_iterator import (
            ManagedResponsesWebSocketHandler,
        )

        init_kwargs = {
            "websocket": MagicMock(),
            "model": model,
            "logging_obj": MagicMock(),
            "custom_llm_provider": custom_llm_provider,
        }
        return ManagedResponsesWebSocketHandler(**init_kwargs)

    def test_none_model_treated_as_same_provider(self):
        """A per-event model of None counts as the same provider."""
        connection = self._handler("openai/gpt-4o")
        assert connection._same_provider(None) is True

    def test_identical_model_is_same_provider(self):
        """The connection's own model string is the same provider."""
        connection = self._handler("openai/gpt-4o")
        assert connection._same_provider("openai/gpt-4o") is True

    def test_same_provider_different_model(self):
        """gpt-4o and gpt-4o-mini are treated as the same provider."""
        connection = self._handler("gpt-4o")
        assert connection._same_provider("gpt-4o-mini") is True

    def test_different_provider_is_not_same(self):
        """A vertex_ai/ model on a gpt-4o connection is a different provider."""
        connection = self._handler("gpt-4o")
        assert connection._same_provider("vertex_ai/gemini-2.0-flash") is False

    def test_inject_credentials_keeps_provider_for_same_provider_model(self):
        """Same-provider per-event models get the connection-level provider."""
        connection = self._handler("gpt-4o", custom_llm_provider="openai")
        injected: dict = {}

        connection._inject_credentials(injected, model="gpt-4o-mini")

        assert injected["custom_llm_provider"] == "openai"

    def test_inject_credentials_drops_provider_for_cross_provider_model(self):
        """Cross-provider per-event models do not get the connection provider."""
        connection = self._handler("gpt-4o", custom_llm_provider="openai")
        injected: dict = {}

        connection._inject_credentials(injected, model="vertex_ai/gemini-2.0-flash")

        assert "custom_llm_provider" not in injected

    def test_unresolvable_connection_model_falls_back_to_custom_provider(self):
        """With a custom alias as connection model, the explicit provider is used."""
        connection = self._handler(
            "my-custom-deployment", custom_llm_provider="openai"
        )
        assert connection._same_provider("gpt-4o-mini") is True

        injected: dict = {}
        connection._inject_credentials(injected, model="gpt-4o-mini")

        assert injected["custom_llm_provider"] == "openai"

    def test_unresolvable_connection_model_still_drops_cross_provider(self):
        """A custom alias connection still omits the provider for a vertex_ai/ model."""
        connection = self._handler(
            "my-custom-deployment", custom_llm_provider="openai"
        )
        injected: dict = {}

        connection._inject_credentials(injected, model="vertex_ai/gemini-2.0-flash")

        assert "custom_llm_provider" not in injected
|