diff --git a/docs/my-website/docs/providers/litellm_proxy.md b/docs/my-website/docs/providers/litellm_proxy.md index efe61e1544..bfefc8a787 100644 --- a/docs/my-website/docs/providers/litellm_proxy.md +++ b/docs/my-website/docs/providers/litellm_proxy.md @@ -9,7 +9,7 @@ import TabItem from '@theme/TabItem'; | Description | LiteLLM Proxy is an OpenAI-compatible gateway that allows you to interact with multiple LLM providers through a unified API. Simply use the `litellm_proxy/` prefix before the model name to route your requests through the proxy. | | Provider Route on LiteLLM | `litellm_proxy/` (add this prefix to the model name, to route any requests to litellm_proxy - e.g. `litellm_proxy/your-model-name`) | | Setup LiteLLM Gateway | [LiteLLM Gateway ↗](../simple_proxy) | -| Supported Endpoints |`/chat/completions`, `/completions`, `/embeddings`, `/audio/speech`, `/audio/transcriptions`, `/images`, `/rerank` | +| Supported Endpoints |`/chat/completions`, `/completions`, `/embeddings`, `/audio/speech`, `/audio/transcriptions`, `/images`, `/images/edits`, `/rerank` | @@ -111,6 +111,21 @@ response = litellm.image_generation( ) ``` +## Image Edit + +```python +import litellm + +with open("your-image.png", "rb") as f: + response = litellm.image_edit( + model="litellm_proxy/gpt-image-1", + prompt="Make this image a watercolor painting", + image=[f], + api_base="your-litellm-proxy-url", + api_key="your-litellm-proxy-api-key", + ) +``` + ## Audio Transcription ```python diff --git a/litellm/images/main.py b/litellm/images/main.py index ca14fabd1f..70d9eb41dd 100644 --- a/litellm/images/main.py +++ b/litellm/images/main.py @@ -369,6 +369,7 @@ def image_generation( # noqa: PLR0915 ) elif ( custom_llm_provider == "openai" + or custom_llm_provider == LlmProviders.LITELLM_PROXY.value or custom_llm_provider in litellm.openai_compatible_providers ): model_response = openai_chat_completions.image_generation( @@ -444,7 +445,6 @@ def image_generation( # noqa: PLR0915 elif custom_llm_provider in ( litellm.LlmProviders.RECRAFT, litellm.LlmProviders.GEMINI, - ): if image_generation_config is None: raise ValueError(f"image generation config is not supported for {custom_llm_provider}") diff --git a/litellm/llms/litellm_proxy/image_edit/transformation.py b/litellm/llms/litellm_proxy/image_edit/transformation.py new file mode 100644 index 0000000000..5f5e2bdb24 --- /dev/null +++ b/litellm/llms/litellm_proxy/image_edit/transformation.py @@ -0,0 +1,26 @@ +from typing import Optional + +from litellm.llms.openai.image_edit.transformation import OpenAIImageEditConfig +from litellm.secret_managers.main import get_secret_str + + +class LiteLLMProxyImageEditConfig(OpenAIImageEditConfig): + """Configuration for image edit requests routed through LiteLLM Proxy.""" + + def validate_environment( + self, headers: dict, model: str, api_key: Optional[str] = None + ) -> dict: + api_key = api_key or get_secret_str("LITELLM_PROXY_API_KEY") + headers.update({"Authorization": f"Bearer {api_key}"}) + return headers + + def get_complete_url( + self, model: str, api_base: Optional[str], litellm_params: dict + ) -> str: + api_base = api_base or get_secret_str("LITELLM_PROXY_API_BASE") + if api_base is None: + raise ValueError( + "api_base not set for LiteLLM Proxy route. Set in env via `LITELLM_PROXY_API_BASE`" + ) + api_base = api_base.rstrip("/") + return f"{api_base}/images/edits" diff --git a/litellm/llms/litellm_proxy/image_generation/transformation.py b/litellm/llms/litellm_proxy/image_generation/transformation.py new file mode 100644 index 0000000000..0ac2b396b0 --- /dev/null +++ b/litellm/llms/litellm_proxy/image_generation/transformation.py @@ -0,0 +1,41 @@ +from typing import List, Optional + +from litellm.llms.openai.image_generation.gpt_transformation import ( + GPTImageGenerationConfig, +) +from litellm.secret_managers.main import get_secret_str +from litellm.types.llms.openai import OpenAIImageGenerationOptionalParams + + +class LiteLLMProxyImageGenerationConfig(GPTImageGenerationConfig): + """Configuration for image generation requests routed through LiteLLM Proxy.""" + def validate_environment( + self, + headers: dict, + model: str, + messages, + optional_params: dict, + litellm_params: dict, + api_key: Optional[str] = None, + api_base: Optional[str] = None, + ) -> dict: + api_key = api_key or get_secret_str("LITELLM_PROXY_API_KEY") + headers.update({"Authorization": f"Bearer {api_key}"}) + return headers + + def get_complete_url( + self, + api_base: Optional[str], + api_key: Optional[str], + model: str, + optional_params: dict, + litellm_params: dict, + stream: Optional[bool] = None, + ) -> str: + api_base = api_base or get_secret_str("LITELLM_PROXY_API_BASE") + if api_base is None: + raise ValueError( + "api_base not set for LiteLLM Proxy route. Set in env via `LITELLM_PROXY_API_BASE`" + ) + api_base = api_base.rstrip("/") + return f"{api_base}/images/generations" diff --git a/litellm/utils.py b/litellm/utils.py index 2ae038e6e7..05055b7df8 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -7305,6 +7305,12 @@ class ProviderConfigManager: ) return get_gemini_image_generation_config(model) + elif LlmProviders.LITELLM_PROXY == provider: + from litellm.llms.litellm_proxy.image_generation.transformation import ( + LiteLLMProxyImageGenerationConfig, + ) + + return LiteLLMProxyImageGenerationConfig() return None @staticmethod @@ -7341,6 +7347,12 @@ class ProviderConfigManager: ) return RecraftImageEditConfig() + elif LlmProviders.LITELLM_PROXY == provider: + from litellm.llms.litellm_proxy.image_edit.transformation import ( + LiteLLMProxyImageEditConfig, + ) + + return LiteLLMProxyImageEditConfig() return None @staticmethod diff --git a/tests/llm_translation/test_litellm_proxy_provider.py b/tests/llm_translation/test_litellm_proxy_provider.py index 1ae507a10d..1a09441979 100644 --- a/tests/llm_translation/test_litellm_proxy_provider.py +++ b/tests/llm_translation/test_litellm_proxy_provider.py @@ -2,6 +2,7 @@ import json import os import sys from datetime import datetime +from io import BytesIO from unittest.mock import AsyncMock sys.path.insert( @@ -184,6 +185,127 @@ async def test_litellm_gateway_from_sdk_image_generation(is_async): assert "dall-e-3" == mock_method.call_args.kwargs["model"] +@pytest.mark.parametrize("is_async", [False, True]) +@pytest.mark.asyncio +async def test_litellm_gateway_image_generation_direct(is_async): + """Test image generation using the litellm_proxy provider directly.""" + litellm._turn_on_debug() + + # Create mock response that matches OpenAI's response structure + mock_openai_response = MagicMock() + mock_openai_response.model_dump.return_value = { + "created": 1, + "data": [{"url": "https://example.com/image.png"}], + } + + if is_async: + # Mock the AsyncOpenAI client that gets created inside _get_openai_client + mock_async_client = AsyncMock() + mock_async_client.images.generate = AsyncMock(return_value=mock_openai_response) + + with patch("litellm.llms.openai.openai.AsyncOpenAI", return_value=mock_async_client) as mock_async_constructor: + response = await litellm.aimage_generation( + model="litellm_proxy/dall-e-3", + prompt="A beautiful sunset over mountains", + api_base="http://my-proxy", + api_key="sk-1234", + ) + + # Verify the AsyncOpenAI client constructor was called with correct parameters + mock_async_constructor.assert_called_once() + constructor_kwargs = mock_async_constructor.call_args.kwargs + print("KWARGS to Async OpenAI constructor=", constructor_kwargs) + assert constructor_kwargs["api_key"] == "sk-1234" + assert constructor_kwargs["base_url"] == "http://my-proxy" + + # Verify the AsyncOpenAI client was called correctly + mock_async_client.images.generate.assert_awaited_once() + call_kwargs = mock_async_client.images.generate.call_args.kwargs + assert call_kwargs["model"] == "dall-e-3" + assert call_kwargs["prompt"] == "A beautiful sunset over mountains" + else: + # Mock the sync OpenAI client that gets created inside _get_openai_client + mock_sync_client = MagicMock() + mock_sync_client.images.generate.return_value = mock_openai_response + + with patch("litellm.llms.openai.openai.OpenAI", return_value=mock_sync_client) as mock_sync_constructor: + response = litellm.image_generation( + model="litellm_proxy/dall-e-3", + prompt="A beautiful sunset over mountains", + api_base="http://my-proxy", + api_key="sk-1234", + ) + + # Verify the OpenAI client constructor was called with correct parameters + mock_sync_constructor.assert_called_once() + constructor_kwargs = mock_sync_constructor.call_args.kwargs + assert constructor_kwargs["api_key"] == "sk-1234" + assert constructor_kwargs["base_url"] == "http://my-proxy" + + # Verify the OpenAI client was called correctly + mock_sync_client.images.generate.assert_called_once() + call_kwargs = mock_sync_client.images.generate.call_args.kwargs + assert call_kwargs["model"] == "dall-e-3" + assert call_kwargs["prompt"] == "A beautiful sunset over mountains" + + # Verify the response structure + assert response is not None + assert hasattr(response, 'data') or isinstance(response, dict) + + +@pytest.mark.parametrize("is_async", [False, True]) +@pytest.mark.asyncio +async def test_litellm_gateway_from_sdk_image_edit(is_async): + litellm._turn_on_debug() + + mock_response = { + "created": 1, + "data": [{"b64_json": ""}], + } + + class MockResponse: + def __init__(self, json_data, status_code): + self._json_data = json_data + self.status_code = status_code + self.text = json.dumps(json_data) + + def json(self): + return self._json_data + + image_file = BytesIO(b"fake-image") + + if is_async: + mock_post = AsyncMock(return_value=MockResponse(mock_response, 200)) + patch_target = "litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post" + else: + mock_post = MagicMock(return_value=MockResponse(mock_response, 200)) + patch_target = "litellm.llms.custom_httpx.http_handler.HTTPHandler.post" + + with patch(patch_target, new=mock_post): + if is_async: + await litellm.aimage_edit( + model="litellm_proxy/gpt-image-1", + prompt="A test prompt", + image=[image_file], + api_base="http://my-proxy", + api_key="sk-1234", + ) + mock_post.assert_awaited_once() + else: + litellm.image_edit( + model="litellm_proxy/gpt-image-1", + prompt="A test prompt", + image=[image_file], + api_base="http://my-proxy", + api_key="sk-1234", + ) + mock_post.assert_called_once() + + called_kwargs = mock_post.call_args.kwargs + assert called_kwargs["url"] == "http://my-proxy/images/edits" + assert called_kwargs["headers"]["Authorization"] == "Bearer sk-1234" + + @pytest.mark.parametrize("is_async", [False, True]) @pytest.mark.asyncio async def test_litellm_gateway_from_sdk_transcription(is_async):