Files
litellm/tests/test_litellm/proxy/image_endpoints/test_endpoints.py
T
YutaSaito bb3b77cbeb feat: add guardrail for image generation (#15619)
* feat: add guardrail for image generation

* fix: use get_str_from_messages
2025-10-16 21:50:30 -07:00

107 lines
3.4 KiB
Python

import asyncio
import copy
from types import SimpleNamespace
from typing import Any, Dict
import orjson
import pytest
from starlette.requests import Request
from starlette.responses import Response
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.image_endpoints import endpoints
@pytest.mark.asyncio
async def test_image_generation_prompt_rerouting(monkeypatch):
    """Check the prompt round-trip through the pre-call hook for image generation.

    The request body carries only ``prompt``. The assertions at the end verify
    that the pre-call hook saw that text as the content of the first
    ``messages`` entry, and that the replacement content returned by the hook
    reached ``route_request`` under ``prompt`` with no ``messages`` key left.
    """
    # Snapshot of what the pre-call hook received, and the payload that was
    # eventually handed to route_request.
    hook_calls: Dict[str, Any] = {}
    routed_payload: Dict[str, Any] = {}

    class FakeResponse(dict):
        # Plain dict plus the `_hidden_params` attribute.
        _hidden_params = {}

    # ---- stand-ins for proxy_server collaborators -------------------------

    async def stub_add_litellm_data_to_request(**kwargs):
        # Pass the request data through unchanged.
        return kwargs["data"]

    async def stub_update_request_status(**_: Any) -> None:
        await asyncio.sleep(0)

    async def stub_pre_call_hook(*, user_api_key_dict, data, call_type):  # type: ignore[override]
        # Deep-copy so later changes to `data` cannot alter the snapshot.
        hook_calls["pre_call_input"] = copy.deepcopy(data)
        rewritten = dict(data)
        rewritten["messages"] = [
            {
                "role": "user",
                "content": "sanitized prompt",
            }
        ]
        return rewritten

    async def stub_post_call_failure_hook(**_: Any) -> None:
        return None

    async def stub_route_request(*, data, **kwargs):  # type: ignore[override]
        routed_payload.update(data)

        async def _produce_response():
            return FakeResponse(result="ok")

        # Hand back the un-awaited coroutine object, exactly one level of
        # awaiting is left to the caller.
        return _produce_response()

    logger_double = SimpleNamespace(
        pre_call_hook=stub_pre_call_hook,
        update_request_status=stub_update_request_status,
        post_call_failure_hook=stub_post_call_failure_hook,
    )

    # ---- minimal ASGI request carrying the original prompt ----------------

    payload = orjson.dumps({"prompt": "original prompt"})

    async def receive():
        return {"type": "http.request", "body": payload, "more_body": False}

    request = Request(
        {
            "type": "http",
            "method": "POST",
            "path": "/v1/images/generations",
            "headers": [],
        },
        receive,
    )
    response = Response()
    user_api_key = UserAPIKeyAuth()

    # ---- patch module-level state, in a fixed order -----------------------

    patches = (
        (
            "litellm.proxy.proxy_server.add_litellm_data_to_request",
            stub_add_litellm_data_to_request,
        ),
        ("litellm.proxy.proxy_server.general_settings", {}),
        ("litellm.proxy.proxy_server.llm_router", None),
        ("litellm.proxy.proxy_server.proxy_config", {}),
        ("litellm.proxy.proxy_server.proxy_logging_obj", logger_double),
        ("litellm.proxy.proxy_server.user_model", None),
        ("litellm.proxy.proxy_server.version", "test-version"),
        (
            "litellm.proxy.common_request_processing.ProxyBaseLLMRequestProcessing.get_custom_headers",
            classmethod(lambda *args, **kwargs: {}),
        ),
        (
            "litellm.proxy.image_endpoints.endpoints.route_request",
            stub_route_request,
        ),
    )
    for target, replacement in patches:
        monkeypatch.setattr(target, replacement)

    # ---- exercise the endpoint --------------------------------------------

    result = await endpoints.image_generation(
        request=request,
        fastapi_response=response,
        user_api_key_dict=user_api_key,
    )
    # Yield to the event loop once before asserting.
    await asyncio.sleep(0)

    # ---- verify -----------------------------------------------------------

    assert result == {"result": "ok"}

    seen_by_hook = hook_calls["pre_call_input"]
    assert seen_by_hook["messages"][0]["content"] == "original prompt"

    assert routed_payload["prompt"] == "sanitized prompt"
    assert "messages" not in routed_payload