mirror of
https://github.com/tiennm99/litellm.git
synced 2026-06-18 00:48:01 +00:00
bb3b77cbeb
* feat: add guardrail for image generation * fix: use get_str_from_messages
107 lines
3.4 KiB
Python
107 lines
3.4 KiB
Python
import asyncio
|
|
import copy
|
|
from types import SimpleNamespace
|
|
from typing import Any, Dict
|
|
|
|
import orjson
|
|
import pytest
|
|
from starlette.requests import Request
|
|
from starlette.responses import Response
|
|
|
|
from litellm.proxy._types import UserAPIKeyAuth
|
|
from litellm.proxy.image_endpoints import endpoints
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_image_generation_prompt_rerouting(monkeypatch):
    """Verify the guardrail round trip for an image-generation prompt.

    The request body carries only ``prompt``. The stubbed pre-call hook
    records the payload it receives and answers with a rewritten ``messages``
    entry; the stubbed ``route_request`` records what is finally dispatched.
    The assertions check that the hook saw the original prompt as the content
    of the first message, and that the dispatched payload carries the
    rewritten text under ``prompt`` with no ``messages`` key left behind.
    """

    # What the pre-call hook was given, and what reached route_request.
    hook_observations: Dict[str, Any] = {}
    dispatched_payload: Dict[str, Any] = {}

    async def passthrough_add_litellm_data(**call_kwargs):
        # Return the request data untouched instead of enriching it.
        return call_kwargs["data"]

    async def noop_update_request_status(**_: Any) -> None:
        # Yield to the event loop once; nothing is recorded.
        await asyncio.sleep(0)

    async def recording_pre_call_hook(*, user_api_key_dict, data, call_type):
        # Snapshot first so later mutation of ``data`` cannot alter the record.
        hook_observations["pre_call_input"] = copy.deepcopy(data)
        # Stand in for a guardrail that rewrites the message content.
        sanitized = dict(data)
        sanitized["messages"] = [
            {"role": "user", "content": "sanitized prompt"},
        ]
        return sanitized

    async def noop_post_call_failure_hook(**_: Any) -> None:
        return None

    async def recording_route_request(*, data, **kwargs):
        dispatched_payload.update(data)

        async def produce_response():
            class FakeResponse(dict):
                _hidden_params = {}

            return FakeResponse({"result": "ok"})

        # The stub hands back an un-awaited coroutine, so the caller has to
        # await the result a second time — presumably mirroring the real
        # route_request; confirm against the endpoint implementation.
        return produce_response()

    stub_proxy_logger = SimpleNamespace(
        pre_call_hook=recording_pre_call_hook,
        update_request_status=noop_update_request_status,
        post_call_failure_hook=noop_post_call_failure_hook,
    )

    # Minimal ASGI request whose body is delivered in a single chunk.
    raw_body = orjson.dumps({"prompt": "original prompt"})

    async def receive():
        return {"type": "http.request", "body": raw_body, "more_body": False}

    asgi_scope = {
        "type": "http",
        "method": "POST",
        "path": "/v1/images/generations",
        "headers": [],
    }
    request = Request(asgi_scope, receive)
    response = Response()
    user_api_key = UserAPIKeyAuth()

    # Applied in this exact order; each entry is (dotted target, replacement).
    patches = [
        (
            "litellm.proxy.proxy_server.add_litellm_data_to_request",
            passthrough_add_litellm_data,
        ),
        ("litellm.proxy.proxy_server.general_settings", {}),
        ("litellm.proxy.proxy_server.llm_router", None),
        ("litellm.proxy.proxy_server.proxy_config", {}),
        ("litellm.proxy.proxy_server.proxy_logging_obj", stub_proxy_logger),
        ("litellm.proxy.proxy_server.user_model", None),
        ("litellm.proxy.proxy_server.version", "test-version"),
        (
            "litellm.proxy.common_request_processing.ProxyBaseLLMRequestProcessing.get_custom_headers",
            classmethod(lambda *args, **kwargs: {}),
        ),
        ("litellm.proxy.image_endpoints.endpoints.route_request", recording_route_request),
    ]
    for target, replacement in patches:
        monkeypatch.setattr(target, replacement)

    result = await endpoints.image_generation(
        request=request,
        fastapi_response=response,
        user_api_key_dict=user_api_key,
    )
    # Give any work scheduled by the endpoint one turn of the event loop.
    await asyncio.sleep(0)

    assert result == {"result": "ok"}

    # The hook must have seen the original prompt as message content ...
    seen_by_hook = hook_observations["pre_call_input"]
    assert seen_by_hook["messages"][0]["content"] == "original prompt"

    # ... and the routed payload must carry the rewritten text as ``prompt``
    # with the temporary ``messages`` key removed.
    assert dispatched_payload["prompt"] == "sanitized prompt"
    assert "messages" not in dispatched_payload
|