Files
litellm/tests/test_litellm/proxy/test_response_model_sanitization.py
T
2026-04-17 13:02:59 -07:00

364 lines
11 KiB
Python

import asyncio
import json
import os
import sys
from typing import AsyncGenerator
from unittest.mock import AsyncMock, MagicMock
import pytest
import yaml
from fastapi.testclient import TestClient
sys.path.insert(0, os.path.abspath("../../.."))
import litellm
# Module-wide pytest marker applied to every test in this file.
# NOTE(review): presumably switches off automatic flaky-test reruns for this
# module via the rerun plugin's ``condition`` argument — confirm against the
# plugin actually installed in CI.
pytestmark = pytest.mark.flaky(condition=False)
def _initialize_proxy_with_config(config: dict, tmp_path) -> TestClient:
    """
    Boot the proxy from ``config`` and return a ``TestClient`` for its app.

    The config dict is dumped as YAML to ``proxy_config.yaml`` under
    ``tmp_path`` and the proxy is initialized from that file with debug on.

    IMPORTANT: ``proxy_server.initialize()`` mutates module-level globals, so
    ``cleanup_router_config_variables()`` is called first to stop state from
    an earlier test bleeding into this one.
    """
    from litellm.proxy import proxy_server as _proxy

    # Reset router/config globals before (re)initializing.
    _proxy.cleanup_router_config_variables()

    yaml_text = yaml.safe_dump(config)
    config_path = tmp_path / "proxy_config.yaml"
    config_path.write_text(yaml_text)

    startup = _proxy.initialize(config=str(config_path), debug=True)
    asyncio.run(startup)

    return TestClient(_proxy.app)
def _make_minimal_chat_completion_response(model: str) -> litellm.ModelResponse:
    """
    Build a default ``ModelResponse`` stamped with ``model``.

    The first choice is filled in with the content ``"hello"`` and a
    ``"stop"`` finish reason; everything else keeps its constructor default.
    """
    completion = litellm.ModelResponse()
    completion.model = model

    first_choice = completion.choices[0]
    first_choice.message.content = "hello"  # type: ignore[union-attr]
    first_choice.finish_reason = "stop"  # type: ignore[union-attr]

    return completion
def _make_model_response_stream_chunk(model: str) -> litellm.ModelResponseStream:
    """
    Build a minimal OpenAI-compatible ``chat.completion.chunk`` for ``model``.

    The chunk carries a single choice whose delta is an assistant message
    with the content ``"hi"`` and no finish reason.
    """
    assistant_delta = {"role": "assistant", "content": "hi"}
    only_choice = {
        "index": 0,
        "delta": assistant_delta,
        "finish_reason": None,
    }
    return litellm.ModelResponseStream(
        id="chatcmpl-test",
        object="chat.completion.chunk",
        created=0,
        model=model,
        choices=[only_choice],
    )
def test_proxy_chat_completion_does_not_return_provider_prefixed_model(
    tmp_path, monkeypatch
):
    """
    Regression test for the non-streaming chat completion path.

    Scenario:
    - The client requests `model="vllm-model"` (no provider prefix).
    - The configured deployment behind it is `hosted_vllm/vllm-model`.
    - The `model` field returned to the client must be the requested name
      and must not expose the `hosted_vllm/` prefix.
    """
    requested_model = "vllm-model"
    provider_model = f"hosted_vllm/{requested_model}"

    proxy_config = {
        "general_settings": {"master_key": "sk-1234"},
        "model_list": [
            {
                "model_name": requested_model,
                "litellm_params": {"model": provider_model},
            }
        ],
    }
    client = _initialize_proxy_with_config(config=proxy_config, tmp_path=tmp_path)

    # Imported only after initialization so the router global is populated.
    from litellm.proxy import proxy_server

    # Stub the router call so the test never performs a real network request;
    # the stub answers with the provider-prefixed model name.
    stubbed_completion = _make_minimal_chat_completion_response(model=provider_model)
    monkeypatch.setattr(
        proxy_server.llm_router,  # type: ignore[arg-type]
        "acompletion",
        AsyncMock(return_value=stubbed_completion),
    )

    # Neutralize the proxy logging hooks to keep the test focused and
    # deterministic; the success hook simply hands the response back.
    def _return_response(**kwargs):
        return kwargs["response"]

    logging_hook_stubs = {
        "during_call_hook": AsyncMock(return_value=None),
        "update_request_status": AsyncMock(return_value=None),
        "post_call_success_hook": AsyncMock(side_effect=_return_response),
    }
    for hook_name, hook_stub in logging_hook_stubs.items():
        monkeypatch.setattr(proxy_server.proxy_logging_obj, hook_name, hook_stub)

    request_body = {
        "model": requested_model,
        "messages": [{"role": "user", "content": "hi"}],
    }
    http_response = client.post(
        "/v1/chat/completions",
        headers={"Authorization": "Bearer sk-1234"},
        json=request_body,
    )
    assert http_response.status_code == 200, http_response.text

    returned_model = http_response.json()["model"]
    assert returned_model == requested_model
    assert not returned_model.startswith("hosted_vllm/")
@pytest.mark.asyncio
async def test_proxy_streaming_chunks_do_not_return_provider_prefixed_model(
    monkeypatch,
):
    """
    Regression test for the streaming (SSE) path.

    A chunk arriving with `model="hosted_vllm/<...>"` must reach the client
    with the client-facing model name, not the provider-prefixed one.
    """
    requested_model = "vllm-model"
    provider_model = f"hosted_vllm/{requested_model}"

    from litellm.proxy import proxy_server
    from litellm.proxy._types import UserAPIKeyAuth

    # Replace the streaming iterator hook so async_data_generator yields
    # exactly one chunk, stamped with the provider-prefixed model.
    async def _iterator_hook(
        user_api_key_dict: UserAPIKeyAuth,
        response: AsyncGenerator,
        request_data: dict,
    ):
        yield _make_model_response_stream_chunk(model=provider_model)

    # The per-chunk hook just passes the chunk through untouched.
    def _return_response(**kwargs):
        return kwargs["response"]

    logging_obj = proxy_server.proxy_logging_obj
    monkeypatch.setattr(
        logging_obj, "async_post_call_streaming_iterator_hook", _iterator_hook
    )
    monkeypatch.setattr(
        logging_obj,
        "async_post_call_streaming_hook",
        AsyncMock(side_effect=_return_response),
    )

    stream = proxy_server.async_data_generator(
        response=MagicMock(),
        user_api_key_dict=UserAPIKeyAuth(api_key="sk-1234"),
        request_data={"model": requested_model},
    )
    sse_events = [event async for event in stream]

    # Expect at least the JSON data event followed by the [DONE] terminator.
    assert len(sse_events) >= 2

    sse_prefix = "data: "
    first_event = sse_events[0]
    assert first_event.startswith(sse_prefix)

    chunk_json = json.loads(first_event[len(sse_prefix) :].strip())
    assert chunk_json["model"] == requested_model
    assert not chunk_json["model"].startswith("hosted_vllm/")
@pytest.mark.asyncio
async def test_proxy_streaming_chunks_use_client_requested_model_before_alias_mapping(
    monkeypatch,
):
    """
    Regression test for alias mapping on the streaming path.

    - `common_processing_pre_call_logic` may rewrite `request_data["model"]`
      through model_alias_map / key-specific aliases.
    - Non-streaming responses are restamped with the model the client
      originally asked for (captured before that rewrite).
    - Streaming chunks must follow the same rule, otherwise streaming and
      non-streaming calls would report different `model` values.
    """
    alias_requested_by_client = "alias-model"
    resolved_model = "vllm-model"
    provider_model = f"hosted_vllm/{resolved_model}"

    from litellm.proxy import proxy_server
    from litellm.proxy._types import UserAPIKeyAuth

    # Yield a single chunk carrying the provider-prefixed model name.
    async def _iterator_hook(
        user_api_key_dict: UserAPIKeyAuth,
        response: AsyncGenerator,
        request_data: dict,
    ):
        yield _make_model_response_stream_chunk(model=provider_model)

    # The per-chunk hook just passes the chunk through untouched.
    def _return_response(**kwargs):
        return kwargs["response"]

    logging_obj = proxy_server.proxy_logging_obj
    monkeypatch.setattr(
        logging_obj, "async_post_call_streaming_iterator_hook", _iterator_hook
    )
    monkeypatch.setattr(
        logging_obj,
        "async_post_call_streaming_hook",
        AsyncMock(side_effect=_return_response),
    )

    # request_data mimics the post-rewrite state: "model" already holds the
    # resolved name while the original alias is kept under a private key.
    rewritten_request = {
        "model": resolved_model,
        "_litellm_client_requested_model": alias_requested_by_client,
    }
    stream = proxy_server.async_data_generator(
        response=MagicMock(),
        user_api_key_dict=UserAPIKeyAuth(api_key="sk-1234"),
        request_data=rewritten_request,
    )
    sse_events = [event async for event in stream]
    assert len(sse_events) >= 2

    sse_prefix = "data: "
    first_event = sse_events[0]
    assert first_event.startswith(sse_prefix)

    chunk_json = json.loads(first_event[len(sse_prefix) :].strip())
    assert chunk_json["model"] == alias_requested_by_client
    assert not chunk_json["model"].startswith("hosted_vllm/")
@pytest.mark.asyncio
async def test_proxy_streaming_azure_model_router_preserves_actual_model(monkeypatch):
    """
    Regression test for Azure Model Router on the streaming path.

    When the client requests azure_ai/model_router, each streamed chunk must
    keep the model that actually served the request (e.g.
    azure_ai/gpt-5-nano-2025-08-07) as reported downstream, rather than being
    overwritten with the router model name.
    """
    requested_router = "azure_ai/model_router"
    serving_model = "azure_ai/gpt-5-nano-2025-08-07"

    from litellm.proxy import proxy_server
    from litellm.proxy._types import UserAPIKeyAuth

    # Yield a single chunk stamped with the model that served the request.
    async def _iterator_hook(
        user_api_key_dict: UserAPIKeyAuth,
        response: AsyncGenerator,
        request_data: dict,
    ):
        yield _make_model_response_stream_chunk(model=serving_model)

    # The per-chunk hook just passes the chunk through untouched.
    def _return_response(**kwargs):
        return kwargs["response"]

    logging_obj = proxy_server.proxy_logging_obj
    monkeypatch.setattr(
        logging_obj, "async_post_call_streaming_iterator_hook", _iterator_hook
    )
    monkeypatch.setattr(
        logging_obj,
        "async_post_call_streaming_hook",
        AsyncMock(side_effect=_return_response),
    )

    router_request = {
        "model": requested_router,
        "_litellm_client_requested_model": requested_router,
    }
    stream = proxy_server.async_data_generator(
        response=MagicMock(),
        user_api_key_dict=UserAPIKeyAuth(api_key="sk-1234"),
        request_data=router_request,
    )
    sse_events = [event async for event in stream]
    assert len(sse_events) >= 2

    sse_prefix = "data: "
    first_event = sse_events[0]
    assert first_event.startswith(sse_prefix)

    chunk_json = json.loads(first_event[len(sse_prefix) :].strip())
    # Azure Model Router: the serving model wins over the router model.
    assert chunk_json["model"] == serving_model
    assert chunk_json["model"] != requested_router
@pytest.mark.asyncio
async def test_proxy_streaming_fastest_response_preserves_winning_model(monkeypatch):
    """
    Regression test for fastest_response on the streaming path.

    When the client sends a comma-separated model list together with
    fastest_response=True, each streamed chunk must keep the name of the
    winning model as reported downstream, rather than being overwritten with
    the comma-separated list.
    """
    requested_model_list = "openai/gpt-4o,gemini/gemini-2.5-flash"
    fastest_model = "gemini-2.5-flash"

    from litellm.proxy import proxy_server
    from litellm.proxy._types import UserAPIKeyAuth

    # Yield a single chunk stamped with the winning model's name.
    async def _iterator_hook(
        user_api_key_dict: UserAPIKeyAuth,
        response: AsyncGenerator,
        request_data: dict,
    ):
        yield _make_model_response_stream_chunk(model=fastest_model)

    # The per-chunk hook just passes the chunk through untouched.
    def _return_response(**kwargs):
        return kwargs["response"]

    logging_obj = proxy_server.proxy_logging_obj
    monkeypatch.setattr(
        logging_obj, "async_post_call_streaming_iterator_hook", _iterator_hook
    )
    monkeypatch.setattr(
        logging_obj,
        "async_post_call_streaming_hook",
        AsyncMock(side_effect=_return_response),
    )

    race_request = {
        "model": requested_model_list,
        "_litellm_client_requested_model": requested_model_list,
        "fastest_response": True,
    }
    stream = proxy_server.async_data_generator(
        response=MagicMock(),
        user_api_key_dict=UserAPIKeyAuth(api_key="sk-1234"),
        request_data=race_request,
    )
    sse_events = [event async for event in stream]
    assert len(sse_events) >= 2

    sse_prefix = "data: "
    first_event = sse_events[0]
    assert first_event.startswith(sse_prefix)

    chunk_json = json.loads(first_event[len(sse_prefix) :].strip())
    assert chunk_json["model"] == fastest_model
    assert chunk_json["model"] != requested_model_list