mirror of
https://github.com/tiennm99/litellm.git
synced 2026-06-22 21:35:46 +00:00
ba17f51812
* fix(proxy): prevent provider-prefixed model leaks Proxy clients should not see LiteLLM internal provider prefixes (e.g. hosted_vllm/...) in the OpenAI-compatible response model field. This patch sanitizes the client-facing model name for both: - Non-streaming responses returned from base_process_llm_request - Streaming SSE chunks emitted by async_data_generator Adds regression tests covering vLLM-style hosted_vllm routing for both streaming and non-streaming paths. * chore(lint): suppress PLR0915 in proxy handler Ruff started flagging ProxyBaseLLMRequestProcessing.base_process_llm_request() for too many statements after the hotpatch changes. Add an explicit '# noqa: PLR0915' on the function definition to avoid a large refactor in a hotpatch. * refactor(proxy): make model restamp explicit Replace silent try/except/pass and type ignores with explicit model restamping. - Logs an error when the downstream response model differs from the client-requested model - Overwrites the OpenAI `model` field to the client-requested value to avoid leaking internal provider-prefixed identifiers - Applies the same behavior to streaming chunks, logging the mismatch only once per stream * chore(lint): drop PLR0915 suppression The model restamping bugfix made `base_process_llm_request()` slightly exceed Ruff's PLR0915 (too-many-statements) threshold, requiring a `# noqa` suppression. Collapse consecutive `hidden_params` extractions into tuple unpacking so the function falls back under the lint limit and remove the suppression. No functional change intended; this keeps the proxy model-field bugfix intact while aligning with project linting rules. * chore(proxy): log model mismatches as warnings These model-restamping logs are intentionally verbose: a mismatch is a useful signal that an internal provider/deployment identifier may be leaking into the public OpenAI response `model` field. - Downgrade model mismatch logs from error -> warning - Keep error logs only for cases where the proxy cannot read/override the model * fix(proxy): preserve client model for streaming aliasing Pre-call processing can rewrite request_data['model'] via model alias maps.\n\nOur streaming SSE generator was using the rewritten value when restamping chunk.model, which caused the public 'model' field to differ between streaming and non-streaming responses for alias-based requests.\n\nStash the original client model in request_data as _litellm_client_requested_model after the model has been routed, and prefer it when overriding the outgoing chunk model. Add a regression test for the alias-mapping case. * chore(lint): satisfy PLR0915 in streaming generator Ruff started flagging async_data_generator() for too many statements after adding model restamping logic.\n\nExtract the client-model selection + chunk restamping into small helpers to keep behavior unchanged while meeting the project's PLR0915 threshold.
218 lines
7.2 KiB
Python
218 lines
7.2 KiB
Python
import asyncio
|
|
import json
|
|
import os
|
|
import sys
|
|
from typing import AsyncGenerator
|
|
from unittest.mock import AsyncMock, MagicMock
|
|
|
|
import pytest
|
|
import yaml
|
|
from fastapi.testclient import TestClient
|
|
|
|
sys.path.insert(0, os.path.abspath("../../.."))
|
|
|
|
import litellm
|
|
|
|
pytestmark = pytest.mark.flaky(condition=False)
|
|
|
|
|
|
def _initialize_proxy_with_config(config: dict, tmp_path) -> TestClient:
|
|
"""
|
|
Initialize the proxy server with a temporary config file and return a TestClient.
|
|
|
|
IMPORTANT: proxy_server.initialize() mutates module-level globals. We must call
|
|
cleanup_router_config_variables() before initializing to prevent cross-test bleed.
|
|
"""
|
|
from litellm.proxy.proxy_server import app, cleanup_router_config_variables, initialize
|
|
|
|
cleanup_router_config_variables()
|
|
|
|
config_fp = tmp_path / "proxy_config.yaml"
|
|
config_fp.write_text(yaml.safe_dump(config))
|
|
|
|
asyncio.run(initialize(config=str(config_fp), debug=True))
|
|
return TestClient(app)
|
|
|
|
|
|
def _make_minimal_chat_completion_response(model: str) -> litellm.ModelResponse:
|
|
response = litellm.ModelResponse()
|
|
response.model = model
|
|
response.choices[0].message.content = "hello" # type: ignore[union-attr]
|
|
response.choices[0].finish_reason = "stop" # type: ignore[union-attr]
|
|
return response
|
|
|
|
|
|
def _make_model_response_stream_chunk(model: str) -> litellm.ModelResponseStream:
|
|
"""
|
|
Create a minimal OpenAI-compatible chat.completion.chunk object.
|
|
"""
|
|
chunk_dict = {
|
|
"id": "chatcmpl-test",
|
|
"object": "chat.completion.chunk",
|
|
"created": 0,
|
|
"model": model,
|
|
"choices": [
|
|
{
|
|
"index": 0,
|
|
"delta": {"role": "assistant", "content": "hi"},
|
|
"finish_reason": None,
|
|
}
|
|
],
|
|
}
|
|
return litellm.ModelResponseStream(**chunk_dict)
|
|
|
|
|
|
def test_proxy_chat_completion_does_not_return_provider_prefixed_model(tmp_path, monkeypatch):
|
|
"""
|
|
Regression test:
|
|
|
|
- Client asks for `model="vllm-model"` (no provider prefix)
|
|
- Internal provider path uses `hosted_vllm/...`
|
|
- Proxy should not leak `hosted_vllm/` in the client-facing `model` field.
|
|
"""
|
|
client_model = "vllm-model"
|
|
internal_model = f"hosted_vllm/{client_model}"
|
|
|
|
client = _initialize_proxy_with_config(
|
|
config={
|
|
"general_settings": {"master_key": "sk-1234"},
|
|
"model_list": [
|
|
{
|
|
"model_name": client_model,
|
|
"litellm_params": {"model": internal_model},
|
|
}
|
|
],
|
|
},
|
|
tmp_path=tmp_path,
|
|
)
|
|
|
|
# Patch router call to avoid making any real network request.
|
|
from litellm.proxy import proxy_server
|
|
|
|
monkeypatch.setattr(
|
|
proxy_server.llm_router, # type: ignore[arg-type]
|
|
"acompletion",
|
|
AsyncMock(return_value=_make_minimal_chat_completion_response(model=internal_model)),
|
|
)
|
|
|
|
# Also no-op proxy logging hooks to keep this test focused and deterministic.
|
|
monkeypatch.setattr(proxy_server.proxy_logging_obj, "during_call_hook", AsyncMock(return_value=None))
|
|
monkeypatch.setattr(proxy_server.proxy_logging_obj, "update_request_status", AsyncMock(return_value=None))
|
|
monkeypatch.setattr(proxy_server.proxy_logging_obj, "post_call_success_hook", AsyncMock(side_effect=lambda **kwargs: kwargs["response"]))
|
|
|
|
resp = client.post(
|
|
"/v1/chat/completions",
|
|
headers={"Authorization": "Bearer sk-1234"},
|
|
json={"model": client_model, "messages": [{"role": "user", "content": "hi"}]},
|
|
)
|
|
|
|
assert resp.status_code == 200, resp.text
|
|
body = resp.json()
|
|
assert body["model"] == client_model
|
|
assert not body["model"].startswith("hosted_vllm/")
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_proxy_streaming_chunks_do_not_return_provider_prefixed_model(monkeypatch):
|
|
"""
|
|
Regression test for streaming:
|
|
|
|
Even if a streaming chunk contains `model="hosted_vllm/<...>"`, the proxy SSE layer
|
|
should not leak the provider prefix to the client.
|
|
"""
|
|
client_model = "vllm-model"
|
|
internal_model = f"hosted_vllm/{client_model}"
|
|
|
|
from litellm.proxy._types import UserAPIKeyAuth
|
|
from litellm.proxy import proxy_server
|
|
|
|
# Patch proxy_logging_obj hooks so async_data_generator yields exactly our chunk.
|
|
async def _iterator_hook(
|
|
user_api_key_dict: UserAPIKeyAuth,
|
|
response: AsyncGenerator,
|
|
request_data: dict,
|
|
):
|
|
yield _make_model_response_stream_chunk(model=internal_model)
|
|
|
|
monkeypatch.setattr(proxy_server.proxy_logging_obj, "async_post_call_streaming_iterator_hook", _iterator_hook)
|
|
monkeypatch.setattr(
|
|
proxy_server.proxy_logging_obj,
|
|
"async_post_call_streaming_hook",
|
|
AsyncMock(side_effect=lambda **kwargs: kwargs["response"]),
|
|
)
|
|
|
|
user_api_key_dict = UserAPIKeyAuth(api_key="sk-1234")
|
|
|
|
gen = proxy_server.async_data_generator(
|
|
response=MagicMock(),
|
|
user_api_key_dict=user_api_key_dict,
|
|
request_data={"model": client_model},
|
|
)
|
|
|
|
chunks = []
|
|
async for item in gen:
|
|
chunks.append(item)
|
|
|
|
# First chunk is expected to be JSON, last chunk is [DONE]
|
|
assert len(chunks) >= 2
|
|
first = chunks[0]
|
|
assert first.startswith("data: ")
|
|
|
|
payload = json.loads(first[len("data: ") :].strip())
|
|
assert payload["model"] == client_model
|
|
assert not payload["model"].startswith("hosted_vllm/")
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_proxy_streaming_chunks_use_client_requested_model_before_alias_mapping(monkeypatch):
|
|
"""
|
|
Regression test for alias mapping on streaming:
|
|
|
|
- `common_processing_pre_call_logic` can rewrite `request_data["model"]` via model_alias_map / key-specific aliases.
|
|
- Non-streaming responses are restamped using the original client-requested model (captured before the rewrite).
|
|
- Streaming chunks must do the same to avoid mismatched `model` values between streaming and non-streaming.
|
|
"""
|
|
client_model_alias = "alias-model"
|
|
canonical_model = "vllm-model"
|
|
internal_model = f"hosted_vllm/{canonical_model}"
|
|
|
|
from litellm.proxy._types import UserAPIKeyAuth
|
|
from litellm.proxy import proxy_server
|
|
|
|
async def _iterator_hook(
|
|
user_api_key_dict: UserAPIKeyAuth,
|
|
response: AsyncGenerator,
|
|
request_data: dict,
|
|
):
|
|
yield _make_model_response_stream_chunk(model=internal_model)
|
|
|
|
monkeypatch.setattr(proxy_server.proxy_logging_obj, "async_post_call_streaming_iterator_hook", _iterator_hook)
|
|
monkeypatch.setattr(
|
|
proxy_server.proxy_logging_obj,
|
|
"async_post_call_streaming_hook",
|
|
AsyncMock(side_effect=lambda **kwargs: kwargs["response"]),
|
|
)
|
|
|
|
user_api_key_dict = UserAPIKeyAuth(api_key="sk-1234")
|
|
|
|
gen = proxy_server.async_data_generator(
|
|
response=MagicMock(),
|
|
user_api_key_dict=user_api_key_dict,
|
|
request_data={
|
|
"model": canonical_model,
|
|
"_litellm_client_requested_model": client_model_alias,
|
|
},
|
|
)
|
|
|
|
chunks = []
|
|
async for item in gen:
|
|
chunks.append(item)
|
|
|
|
assert len(chunks) >= 2
|
|
first = chunks[0]
|
|
assert first.startswith("data: ")
|
|
|
|
payload = json.loads(first[len("data: ") :].strip())
|
|
assert payload["model"] == client_model_alias
|
|
assert not payload["model"].startswith("hosted_vllm/")
|