mirror of
https://github.com/tiennm99/litellm.git
synced 2026-06-18 05:28:02 +00:00
[Bug Fix] image_edit() function returns APIConnectionError with litellm_proxy - Support for both image edits and image generations (#13735)
* add image edits litellm proxy on SDK * add image gen provider * add IMG Gen support for litellm_proxy provider
This commit is contained in:
@@ -9,7 +9,7 @@ import TabItem from '@theme/TabItem';
|
||||
| Description | LiteLLM Proxy is an OpenAI-compatible gateway that allows you to interact with multiple LLM providers through a unified API. Simply use the `litellm_proxy/` prefix before the model name to route your requests through the proxy. |
|
||||
| Provider Route on LiteLLM | `litellm_proxy/` (add this prefix to the model name, to route any requests to litellm_proxy - e.g. `litellm_proxy/your-model-name`) |
|
||||
| Setup LiteLLM Gateway | [LiteLLM Gateway ↗](../simple_proxy) |
|
||||
| Supported Endpoints |`/chat/completions`, `/completions`, `/embeddings`, `/audio/speech`, `/audio/transcriptions`, `/images`, `/rerank` |
|
||||
| Supported Endpoints |`/chat/completions`, `/completions`, `/embeddings`, `/audio/speech`, `/audio/transcriptions`, `/images`, `/images/edits`, `/rerank` |
|
||||
|
||||
|
||||
|
||||
@@ -111,6 +111,21 @@ response = litellm.image_generation(
|
||||
)
|
||||
```
|
||||
|
||||
## Image Edit
|
||||
|
||||
```python
|
||||
import litellm
|
||||
|
||||
with open("your-image.png", "rb") as f:
|
||||
response = litellm.image_edit(
|
||||
model="litellm_proxy/gpt-image-1",
|
||||
prompt="Make this image a watercolor painting",
|
||||
image=[f],
|
||||
api_base="your-litellm-proxy-url",
|
||||
api_key="your-litellm-proxy-api-key",
|
||||
)
|
||||
```
|
||||
|
||||
## Audio Transcription
|
||||
|
||||
```python
|
||||
|
||||
@@ -369,6 +369,7 @@ def image_generation( # noqa: PLR0915
|
||||
)
|
||||
elif (
|
||||
custom_llm_provider == "openai"
|
||||
or custom_llm_provider == LlmProviders.LITELLM_PROXY.value
|
||||
or custom_llm_provider in litellm.openai_compatible_providers
|
||||
):
|
||||
model_response = openai_chat_completions.image_generation(
|
||||
@@ -444,7 +445,6 @@ def image_generation( # noqa: PLR0915
|
||||
elif custom_llm_provider in (
|
||||
litellm.LlmProviders.RECRAFT,
|
||||
litellm.LlmProviders.GEMINI,
|
||||
|
||||
):
|
||||
if image_generation_config is None:
|
||||
raise ValueError(f"image generation config is not supported for {custom_llm_provider}")
|
||||
|
||||
@@ -0,0 +1,26 @@
|
||||
from typing import Optional
|
||||
|
||||
from litellm.llms.openai.image_edit.transformation import OpenAIImageEditConfig
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
|
||||
|
||||
class LiteLLMProxyImageEditConfig(OpenAIImageEditConfig):
|
||||
"""Configuration for image edit requests routed through LiteLLM Proxy."""
|
||||
|
||||
def validate_environment(
|
||||
self, headers: dict, model: str, api_key: Optional[str] = None
|
||||
) -> dict:
|
||||
api_key = api_key or get_secret_str("LITELLM_PROXY_API_KEY")
|
||||
headers.update({"Authorization": f"Bearer {api_key}"})
|
||||
return headers
|
||||
|
||||
def get_complete_url(
|
||||
self, model: str, api_base: Optional[str], litellm_params: dict
|
||||
) -> str:
|
||||
api_base = api_base or get_secret_str("LITELLM_PROXY_API_BASE")
|
||||
if api_base is None:
|
||||
raise ValueError(
|
||||
"api_base not set for LiteLLM Proxy route. Set in env via `LITELLM_PROXY_API_BASE`"
|
||||
)
|
||||
api_base = api_base.rstrip("/")
|
||||
return f"{api_base}/images/edits"
|
||||
@@ -0,0 +1,41 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from litellm.llms.openai.image_generation.gpt_transformation import (
|
||||
GPTImageGenerationConfig,
|
||||
)
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
from litellm.types.llms.openai import OpenAIImageGenerationOptionalParams
|
||||
|
||||
|
||||
class LiteLLMProxyImageGenerationConfig(GPTImageGenerationConfig):
|
||||
"""Configuration for image generation requests routed through LiteLLM Proxy."""
|
||||
def validate_environment(
|
||||
self,
|
||||
headers: dict,
|
||||
model: str,
|
||||
messages,
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
api_key = api_key or get_secret_str("LITELLM_PROXY_API_KEY")
|
||||
headers.update({"Authorization": f"Bearer {api_key}"})
|
||||
return headers
|
||||
|
||||
def get_complete_url(
|
||||
self,
|
||||
api_base: Optional[str],
|
||||
api_key: Optional[str],
|
||||
model: str,
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
stream: Optional[bool] = None,
|
||||
) -> str:
|
||||
api_base = api_base or get_secret_str("LITELLM_PROXY_API_BASE")
|
||||
if api_base is None:
|
||||
raise ValueError(
|
||||
"api_base not set for LiteLLM Proxy route. Set in env via `LITELLM_PROXY_API_BASE`"
|
||||
)
|
||||
api_base = api_base.rstrip("/")
|
||||
return f"{api_base}/images/generations"
|
||||
@@ -7305,6 +7305,12 @@ class ProviderConfigManager:
|
||||
)
|
||||
|
||||
return get_gemini_image_generation_config(model)
|
||||
elif LlmProviders.LITELLM_PROXY == provider:
|
||||
from litellm.llms.litellm_proxy.image_generation.transformation import (
|
||||
LiteLLMProxyImageGenerationConfig,
|
||||
)
|
||||
|
||||
return LiteLLMProxyImageGenerationConfig()
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
@@ -7341,6 +7347,12 @@ class ProviderConfigManager:
|
||||
)
|
||||
|
||||
return RecraftImageEditConfig()
|
||||
elif LlmProviders.LITELLM_PROXY == provider:
|
||||
from litellm.llms.litellm_proxy.image_edit.transformation import (
|
||||
LiteLLMProxyImageEditConfig,
|
||||
)
|
||||
|
||||
return LiteLLMProxyImageEditConfig()
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -2,6 +2,7 @@ import json
|
||||
import os
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from io import BytesIO
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
sys.path.insert(
|
||||
@@ -184,6 +185,127 @@ async def test_litellm_gateway_from_sdk_image_generation(is_async):
|
||||
assert "dall-e-3" == mock_method.call_args.kwargs["model"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("is_async", [False, True])
|
||||
@pytest.mark.asyncio
|
||||
async def test_litellm_gateway_image_generation_direct(is_async):
|
||||
"""Test image generation using the litellm_proxy provider directly."""
|
||||
litellm._turn_on_debug()
|
||||
|
||||
# Create mock response that matches OpenAI's response structure
|
||||
mock_openai_response = MagicMock()
|
||||
mock_openai_response.model_dump.return_value = {
|
||||
"created": 1,
|
||||
"data": [{"url": "https://example.com/image.png"}],
|
||||
}
|
||||
|
||||
if is_async:
|
||||
# Mock the AsyncOpenAI client that gets created inside _get_openai_client
|
||||
mock_async_client = AsyncMock()
|
||||
mock_async_client.images.generate = AsyncMock(return_value=mock_openai_response)
|
||||
|
||||
with patch("litellm.llms.openai.openai.AsyncOpenAI", return_value=mock_async_client) as mock_async_constructor:
|
||||
response = await litellm.aimage_generation(
|
||||
model="litellm_proxy/dall-e-3",
|
||||
prompt="A beautiful sunset over mountains",
|
||||
api_base="http://my-proxy",
|
||||
api_key="sk-1234",
|
||||
)
|
||||
|
||||
# Verify the AsyncOpenAI client constructor was called with correct parameters
|
||||
mock_async_constructor.assert_called_once()
|
||||
constructor_kwargs = mock_async_constructor.call_args.kwargs
|
||||
print("KWARGS to Async OpenAI constructor=", constructor_kwargs)
|
||||
assert constructor_kwargs["api_key"] == "sk-1234"
|
||||
assert constructor_kwargs["base_url"] == "http://my-proxy"
|
||||
|
||||
# Verify the AsyncOpenAI client was called correctly
|
||||
mock_async_client.images.generate.assert_awaited_once()
|
||||
call_kwargs = mock_async_client.images.generate.call_args.kwargs
|
||||
assert call_kwargs["model"] == "dall-e-3"
|
||||
assert call_kwargs["prompt"] == "A beautiful sunset over mountains"
|
||||
else:
|
||||
# Mock the sync OpenAI client that gets created inside _get_openai_client
|
||||
mock_sync_client = MagicMock()
|
||||
mock_sync_client.images.generate.return_value = mock_openai_response
|
||||
|
||||
with patch("litellm.llms.openai.openai.OpenAI", return_value=mock_sync_client) as mock_sync_constructor:
|
||||
response = litellm.image_generation(
|
||||
model="litellm_proxy/dall-e-3",
|
||||
prompt="A beautiful sunset over mountains",
|
||||
api_base="http://my-proxy",
|
||||
api_key="sk-1234",
|
||||
)
|
||||
|
||||
# Verify the OpenAI client constructor was called with correct parameters
|
||||
mock_sync_constructor.assert_called_once()
|
||||
constructor_kwargs = mock_sync_constructor.call_args.kwargs
|
||||
assert constructor_kwargs["api_key"] == "sk-1234"
|
||||
assert constructor_kwargs["base_url"] == "http://my-proxy"
|
||||
|
||||
# Verify the OpenAI client was called correctly
|
||||
mock_sync_client.images.generate.assert_called_once()
|
||||
call_kwargs = mock_sync_client.images.generate.call_args.kwargs
|
||||
assert call_kwargs["model"] == "dall-e-3"
|
||||
assert call_kwargs["prompt"] == "A beautiful sunset over mountains"
|
||||
|
||||
# Verify the response structure
|
||||
assert response is not None
|
||||
assert hasattr(response, 'data') or isinstance(response, dict)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("is_async", [False, True])
|
||||
@pytest.mark.asyncio
|
||||
async def test_litellm_gateway_from_sdk_image_edit(is_async):
|
||||
litellm._turn_on_debug()
|
||||
|
||||
mock_response = {
|
||||
"created": 1,
|
||||
"data": [{"b64_json": ""}],
|
||||
}
|
||||
|
||||
class MockResponse:
|
||||
def __init__(self, json_data, status_code):
|
||||
self._json_data = json_data
|
||||
self.status_code = status_code
|
||||
self.text = json.dumps(json_data)
|
||||
|
||||
def json(self):
|
||||
return self._json_data
|
||||
|
||||
image_file = BytesIO(b"fake-image")
|
||||
|
||||
if is_async:
|
||||
mock_post = AsyncMock(return_value=MockResponse(mock_response, 200))
|
||||
patch_target = "litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post"
|
||||
else:
|
||||
mock_post = MagicMock(return_value=MockResponse(mock_response, 200))
|
||||
patch_target = "litellm.llms.custom_httpx.http_handler.HTTPHandler.post"
|
||||
|
||||
with patch(patch_target, new=mock_post):
|
||||
if is_async:
|
||||
await litellm.aimage_edit(
|
||||
model="litellm_proxy/gpt-image-1",
|
||||
prompt="A test prompt",
|
||||
image=[image_file],
|
||||
api_base="http://my-proxy",
|
||||
api_key="sk-1234",
|
||||
)
|
||||
mock_post.assert_awaited_once()
|
||||
else:
|
||||
litellm.image_edit(
|
||||
model="litellm_proxy/gpt-image-1",
|
||||
prompt="A test prompt",
|
||||
image=[image_file],
|
||||
api_base="http://my-proxy",
|
||||
api_key="sk-1234",
|
||||
)
|
||||
mock_post.assert_called_once()
|
||||
|
||||
called_kwargs = mock_post.call_args.kwargs
|
||||
assert called_kwargs["url"] == "http://my-proxy/images/edits"
|
||||
assert called_kwargs["headers"]["Authorization"] == "Bearer sk-1234"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("is_async", [False, True])
|
||||
@pytest.mark.asyncio
|
||||
async def test_litellm_gateway_from_sdk_transcription(is_async):
|
||||
|
||||
Reference in New Issue
Block a user