mirror of
https://github.com/tiennm99/litellm.git
synced 2026-06-18 00:48:01 +00:00
Add Vertex AI Image Edit Support (#16828)
* Add vertex ai image edit support * Fix lint errors
This commit is contained in:
@@ -16,7 +16,7 @@ LiteLLM provides image editing functionality that maps to OpenAI's `/images/edit
|
||||
| Supported operations | Create image edits | Single and multiple images supported |
|
||||
| Supported LiteLLM SDK Versions | 1.63.8+ | Gemini support requires 1.79.3+ |
|
||||
| Supported LiteLLM Proxy Versions | 1.71.1+ | Gemini support requires 1.79.3+ |
|
||||
| Supported LLM providers | **OpenAI**, **Gemini (Google AI Studio)** | Gemini supports the new `gemini-2.5-flash-image` family |
|
||||
| Supported LLM providers | **OpenAI**, **Gemini (Google AI Studio)**, **Vertex AI** | Gemini supports the new `gemini-2.5-flash-image` family. Vertex AI supports both Gemini and Imagen models. |
|
||||
|
||||
#### ⚡️See all supported models and providers at [models.litellm.ai](https://models.litellm.ai/)
|
||||
|
||||
@@ -197,6 +197,53 @@ for idx, image_obj in enumerate(response.data):
|
||||
f.write(base64.b64decode(image_obj.b64_json))
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
|
||||
<TabItem value="vertex_ai" label="Vertex AI">
|
||||
|
||||
#### Basic Image Edit (Gemini)
|
||||
```python showLineNumbers title="Vertex AI Gemini Image Edit"
|
||||
import os
|
||||
import litellm
|
||||
|
||||
# Set Vertex AI credentials
|
||||
os.environ["VERTEXAI_PROJECT"] = "your-gcp-project-id"
|
||||
os.environ["VERTEXAI_LOCATION"] = "us-central1"
|
||||
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = "/path/to/service-account.json"
|
||||
|
||||
response = litellm.image_edit(
|
||||
model="vertex_ai/gemini-2.5-flash",
|
||||
image=open("original_image.png", "rb"),
|
||||
prompt="Add neon lights in the background",
|
||||
size="1024x1024",
|
||||
)
|
||||
|
||||
print(response)
|
||||
```
|
||||
|
||||
#### Image Edit with Imagen (Supports Masks)
|
||||
```python showLineNumbers title="Vertex AI Imagen Image Edit"
|
||||
import os
|
||||
import litellm
|
||||
|
||||
# Set Vertex AI credentials
|
||||
os.environ["VERTEXAI_PROJECT"] = "your-gcp-project-id"
|
||||
os.environ["VERTEXAI_LOCATION"] = "us-central1"
|
||||
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = "/path/to/service-account.json"
|
||||
|
||||
# Imagen supports mask for inpainting
|
||||
response = litellm.image_edit(
|
||||
model="vertex_ai/imagen-3.0-capability-001",
|
||||
image=open("original_image.png", "rb"),
|
||||
mask=open("mask_image.png", "rb"), # Optional: for inpainting
|
||||
prompt="Turn this into watercolor style scenery",
|
||||
n=2, # Number of variations
|
||||
size="1024x1024",
|
||||
)
|
||||
|
||||
print(response)
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
</Tabs>
|
||||
|
||||
@@ -302,6 +349,55 @@ curl -X POST "http://0.0.0.0:4000/v1/images/edits" \
|
||||
-F "size=1024x1024"
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
|
||||
<TabItem value="vertex_ai" label="Vertex AI">
|
||||
|
||||
1. Add Vertex AI image edit models to your `config.yaml`:
|
||||
```yaml showLineNumbers title="Vertex AI Proxy Configuration"
|
||||
model_list:
|
||||
- model_name: vertex-gemini-image-edit
|
||||
litellm_params:
|
||||
model: vertex_ai/gemini-2.5-flash
|
||||
vertex_project: os.environ/VERTEXAI_PROJECT
|
||||
vertex_location: os.environ/VERTEXAI_LOCATION
|
||||
vertex_credentials: os.environ/GOOGLE_APPLICATION_CREDENTIALS
|
||||
|
||||
- model_name: vertex-imagen-image-edit
|
||||
litellm_params:
|
||||
model: vertex_ai/imagen-3.0-capability-001
|
||||
vertex_project: os.environ/VERTEXAI_PROJECT
|
||||
vertex_location: os.environ/VERTEXAI_LOCATION
|
||||
vertex_credentials: os.environ/GOOGLE_APPLICATION_CREDENTIALS
|
||||
```
|
||||
|
||||
2. Start the LiteLLM proxy server:
|
||||
```bash showLineNumbers title="Start LiteLLM Proxy Server"
|
||||
litellm --config /path/to/config.yaml
|
||||
```
|
||||
|
||||
3. Make an image edit request:
|
||||
```bash showLineNumbers title="Vertex AI Gemini Proxy Image Edit"
|
||||
curl -X POST "http://0.0.0.0:4000/v1/images/edits" \
|
||||
-H "Authorization: Bearer <YOUR-LITELLM-KEY>" \
|
||||
-F "model=vertex-gemini-image-edit" \
|
||||
-F "image=@original_image.png" \
|
||||
-F "prompt=Add neon lights in the background" \
|
||||
-F "size=1024x1024"
|
||||
```
|
||||
|
||||
4. Imagen image edit with mask:
|
||||
```bash showLineNumbers title="Vertex AI Imagen Proxy Image Edit with Mask"
|
||||
curl -X POST "http://0.0.0.0:4000/v1/images/edits" \
|
||||
-H "Authorization: Bearer <YOUR-LITELLM-KEY>" \
|
||||
-F "model=vertex-imagen-image-edit" \
|
||||
-F "image=@original_image.png" \
|
||||
-F "mask=@mask_image.png" \
|
||||
-F "prompt=Turn this into watercolor style scenery" \
|
||||
-F "n=2" \
|
||||
-F "size=1024x1024"
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
</Tabs>
|
||||
|
||||
|
||||
@@ -0,0 +1,39 @@
|
||||
from litellm.llms.base_llm.image_edit.transformation import BaseImageEditConfig
|
||||
from litellm.llms.vertex_ai.common_utils import VertexAIModelRoute, get_vertex_ai_model_route
|
||||
|
||||
from .cost_calculator import cost_calculator
|
||||
from .vertex_gemini_transformation import VertexAIGeminiImageEditConfig
|
||||
from .vertex_imagen_transformation import VertexAIImagenImageEditConfig
|
||||
|
||||
__all__ = [
|
||||
"VertexAIGeminiImageEditConfig",
|
||||
"VertexAIImagenImageEditConfig",
|
||||
"get_vertex_ai_image_edit_config",
|
||||
"cost_calculator"
|
||||
]
|
||||
|
||||
|
||||
def get_vertex_ai_image_edit_config(model: str) -> BaseImageEditConfig:
|
||||
"""
|
||||
Get the appropriate image edit config for a Vertex AI model.
|
||||
|
||||
Routes to the correct transformation class based on the model type:
|
||||
- Gemini models use generateContent API (VertexAIGeminiImageEditConfig)
|
||||
- Imagen models use predict API (VertexAIImagenImageEditConfig)
|
||||
|
||||
Args:
|
||||
model: The model name (e.g., "gemini-2.5-flash", "imagegeneration@006")
|
||||
|
||||
Returns:
|
||||
BaseImageEditConfig: The appropriate configuration class
|
||||
"""
|
||||
# Determine the model route
|
||||
model_route = get_vertex_ai_model_route(model)
|
||||
|
||||
if model_route == VertexAIModelRoute.GEMINI:
|
||||
# Gemini models use generateContent API
|
||||
return VertexAIGeminiImageEditConfig()
|
||||
else:
|
||||
# Default to Imagen for other models (imagegeneration, etc.)
|
||||
# This includes NON_GEMINI models like imagegeneration@006
|
||||
return VertexAIImagenImageEditConfig()
|
||||
@@ -0,0 +1,34 @@
|
||||
"""
|
||||
Vertex AI Image Edit Cost Calculator
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
import litellm
|
||||
from litellm.types.utils import ImageResponse
|
||||
|
||||
|
||||
def cost_calculator(
|
||||
model: str,
|
||||
image_response: Any,
|
||||
) -> float:
|
||||
"""
|
||||
Vertex AI image edit cost calculator.
|
||||
|
||||
Mirrors image generation pricing: charge per returned image based on
|
||||
model metadata (`output_cost_per_image`).
|
||||
"""
|
||||
model_info = litellm.get_model_info(
|
||||
model=model,
|
||||
custom_llm_provider="vertex_ai",
|
||||
)
|
||||
|
||||
output_cost_per_image: float = model_info.get("output_cost_per_image") or 0.0
|
||||
|
||||
if not isinstance(image_response, ImageResponse):
|
||||
raise ValueError(
|
||||
f"image_response must be of type ImageResponse got type={type(image_response)}"
|
||||
)
|
||||
|
||||
num_images = len(image_response.data or [])
|
||||
return output_cost_per_image * num_images
|
||||
@@ -0,0 +1,263 @@
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
from io import BufferedReader, BytesIO
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast
|
||||
|
||||
import httpx
|
||||
from httpx._types import RequestFiles
|
||||
|
||||
import litellm
|
||||
|
||||
from litellm.images.utils import ImageEditRequestUtils
|
||||
from litellm.llms.base_llm.image_edit.transformation import BaseImageEditConfig
|
||||
from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import VertexLLM
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
from litellm.types.images.main import ImageEditOptionalRequestParams
|
||||
from litellm.types.router import GenericLiteLLMParams
|
||||
from litellm.types.utils import FileTypes, ImageObject, ImageResponse, OpenAIImage
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
|
||||
|
||||
class VertexAIGeminiImageEditConfig(BaseImageEditConfig, VertexLLM):
|
||||
"""
|
||||
Vertex AI Gemini Image Edit Configuration
|
||||
|
||||
Uses generateContent API for Gemini models on Vertex AI
|
||||
"""
|
||||
SUPPORTED_PARAMS: List[str] = ["size"]
|
||||
|
||||
def __init__(self) -> None:
|
||||
BaseImageEditConfig.__init__(self)
|
||||
VertexLLM.__init__(self)
|
||||
|
||||
def get_supported_openai_params(self, model: str) -> List[str]:
|
||||
return list(self.SUPPORTED_PARAMS)
|
||||
|
||||
def map_openai_params(
|
||||
self,
|
||||
image_edit_optional_params: ImageEditOptionalRequestParams,
|
||||
model: str,
|
||||
drop_params: bool,
|
||||
) -> Dict[str, Any]:
|
||||
supported_params = self.get_supported_openai_params(model)
|
||||
filtered_params = {
|
||||
key: value
|
||||
for key, value in image_edit_optional_params.items()
|
||||
if key in supported_params
|
||||
}
|
||||
|
||||
mapped_params: Dict[str, Any] = {}
|
||||
|
||||
if "size" in filtered_params:
|
||||
mapped_params["aspectRatio"] = self._map_size_to_aspect_ratio(
|
||||
filtered_params["size"] # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
return mapped_params
|
||||
|
||||
def _resolve_vertex_project(self) -> Optional[str]:
|
||||
return (
|
||||
getattr(self, "_vertex_project", None)
|
||||
or os.environ.get("VERTEXAI_PROJECT")
|
||||
or getattr(litellm, "vertex_project", None)
|
||||
or get_secret_str("VERTEXAI_PROJECT")
|
||||
)
|
||||
|
||||
def _resolve_vertex_location(self) -> Optional[str]:
|
||||
return (
|
||||
getattr(self, "_vertex_location", None)
|
||||
or os.environ.get("VERTEXAI_LOCATION")
|
||||
or os.environ.get("VERTEX_LOCATION")
|
||||
or getattr(litellm, "vertex_location", None)
|
||||
or get_secret_str("VERTEXAI_LOCATION")
|
||||
or get_secret_str("VERTEX_LOCATION")
|
||||
)
|
||||
|
||||
def _resolve_vertex_credentials(self) -> Optional[str]:
|
||||
return (
|
||||
getattr(self, "_vertex_credentials", None)
|
||||
or os.environ.get("VERTEXAI_CREDENTIALS")
|
||||
or getattr(litellm, "vertex_credentials", None)
|
||||
or os.environ.get("GOOGLE_APPLICATION_CREDENTIALS")
|
||||
or get_secret_str("VERTEXAI_CREDENTIALS")
|
||||
)
|
||||
|
||||
def validate_environment(
|
||||
self,
|
||||
headers: dict,
|
||||
model: str,
|
||||
api_key: Optional[str] = None,
|
||||
) -> dict:
|
||||
headers = headers or {}
|
||||
vertex_project = self._resolve_vertex_project()
|
||||
vertex_credentials = self._resolve_vertex_credentials()
|
||||
access_token, _ = self._ensure_access_token(
|
||||
credentials=vertex_credentials,
|
||||
project_id=vertex_project,
|
||||
custom_llm_provider="vertex_ai",
|
||||
)
|
||||
return self.set_headers(access_token, headers)
|
||||
|
||||
def get_complete_url(
|
||||
self,
|
||||
model: str,
|
||||
api_base: Optional[str],
|
||||
litellm_params: dict,
|
||||
) -> str:
|
||||
"""
|
||||
Get the complete URL for Vertex AI Gemini generateContent API
|
||||
"""
|
||||
vertex_project = self._resolve_vertex_project()
|
||||
vertex_location = self._resolve_vertex_location()
|
||||
|
||||
if not vertex_project or not vertex_location:
|
||||
raise ValueError("vertex_project and vertex_location are required for Vertex AI")
|
||||
|
||||
# Use the model name as provided, handling vertex_ai prefix
|
||||
model_name = model
|
||||
if model.startswith("vertex_ai/"):
|
||||
model_name = model.replace("vertex_ai/", "")
|
||||
|
||||
if api_base:
|
||||
base_url = api_base.rstrip("/")
|
||||
else:
|
||||
base_url = f"https://{vertex_location}-aiplatform.googleapis.com"
|
||||
|
||||
return f"{base_url}/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model_name}:generateContent"
|
||||
|
||||
def transform_image_edit_request( # type: ignore[override]
|
||||
self,
|
||||
model: str,
|
||||
prompt: str,
|
||||
image: FileTypes,
|
||||
image_edit_optional_request_params: Dict[str, Any],
|
||||
litellm_params: GenericLiteLLMParams,
|
||||
headers: dict,
|
||||
) -> Tuple[Dict[str, Any], Optional[RequestFiles]]:
|
||||
inline_parts = self._prepare_inline_image_parts(image)
|
||||
if not inline_parts:
|
||||
raise ValueError("Vertex AI Gemini image edit requires at least one image.")
|
||||
|
||||
# Correct format for Vertex AI Gemini image editing
|
||||
contents = {
|
||||
"role": "USER",
|
||||
"parts": inline_parts + [{"text": prompt}]
|
||||
}
|
||||
|
||||
request_body: Dict[str, Any] = {"contents": contents}
|
||||
|
||||
# Generation config with proper structure for image editing
|
||||
generation_config: Dict[str, Any] = {
|
||||
"response_modalities": ["IMAGE"]
|
||||
}
|
||||
|
||||
# Add image-specific configuration
|
||||
image_config: Dict[str, Any] = {}
|
||||
if "aspectRatio" in image_edit_optional_request_params:
|
||||
image_config["aspect_ratio"] = image_edit_optional_request_params["aspectRatio"]
|
||||
|
||||
if image_config:
|
||||
generation_config["image_config"] = image_config
|
||||
|
||||
request_body["generationConfig"] = generation_config
|
||||
|
||||
payload: Any = json.dumps(request_body)
|
||||
empty_files = cast(RequestFiles, [])
|
||||
return cast(Tuple[Dict[str, Any], Optional[RequestFiles]], (payload, empty_files))
|
||||
|
||||
def transform_image_edit_response(
|
||||
self,
|
||||
model: str,
|
||||
raw_response: httpx.Response,
|
||||
logging_obj: Any,
|
||||
) -> ImageResponse:
|
||||
model_response = ImageResponse()
|
||||
try:
|
||||
response_json = raw_response.json()
|
||||
except Exception as exc:
|
||||
raise self.get_error_class(
|
||||
error_message=f"Error transforming image edit response: {exc}",
|
||||
status_code=raw_response.status_code,
|
||||
headers=raw_response.headers,
|
||||
)
|
||||
|
||||
candidates = response_json.get("candidates", [])
|
||||
data_list: List[ImageObject] = []
|
||||
|
||||
for candidate in candidates:
|
||||
content = candidate.get("content", {})
|
||||
parts = content.get("parts", [])
|
||||
for part in parts:
|
||||
inline_data = part.get("inlineData")
|
||||
if inline_data and inline_data.get("data"):
|
||||
data_list.append(
|
||||
ImageObject(
|
||||
b64_json=inline_data["data"],
|
||||
url=None,
|
||||
)
|
||||
)
|
||||
|
||||
model_response.data = cast(List[OpenAIImage], data_list)
|
||||
return model_response
|
||||
|
||||
def _map_size_to_aspect_ratio(self, size: str) -> str:
|
||||
"""Map OpenAI size format to Gemini aspect ratio format"""
|
||||
aspect_ratio_map = {
|
||||
"1024x1024": "1:1",
|
||||
"1792x1024": "16:9",
|
||||
"1024x1792": "9:16",
|
||||
"1280x896": "4:3",
|
||||
"896x1280": "3:4",
|
||||
}
|
||||
return aspect_ratio_map.get(size, "1:1")
|
||||
|
||||
def _prepare_inline_image_parts(
|
||||
self, image: Union[FileTypes, List[FileTypes]]
|
||||
) -> List[Dict[str, Any]]:
|
||||
images: List[FileTypes]
|
||||
if isinstance(image, list):
|
||||
images = image
|
||||
else:
|
||||
images = [image]
|
||||
|
||||
inline_parts: List[Dict[str, Any]] = []
|
||||
for img in images:
|
||||
if img is None:
|
||||
continue
|
||||
|
||||
mime_type = ImageEditRequestUtils.get_image_content_type(img)
|
||||
image_bytes = self._read_all_bytes(img)
|
||||
inline_parts.append(
|
||||
{
|
||||
"inlineData": {
|
||||
"mimeType": mime_type,
|
||||
"data": base64.b64encode(image_bytes).decode("utf-8"),
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
return inline_parts
|
||||
|
||||
def _read_all_bytes(self, image: FileTypes) -> bytes:
|
||||
if isinstance(image, bytes):
|
||||
return image
|
||||
if isinstance(image, BytesIO):
|
||||
current_pos = image.tell()
|
||||
image.seek(0)
|
||||
data = image.read()
|
||||
image.seek(current_pos)
|
||||
return data
|
||||
if isinstance(image, BufferedReader):
|
||||
current_pos = image.tell()
|
||||
image.seek(0)
|
||||
data = image.read()
|
||||
image.seek(current_pos)
|
||||
return data
|
||||
raise ValueError("Unsupported image type for Vertex AI Gemini image edit.")
|
||||
@@ -0,0 +1,345 @@
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
from io import BufferedRandom, BufferedReader, BytesIO
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast
|
||||
|
||||
import httpx
|
||||
from httpx._types import RequestFiles
|
||||
|
||||
import litellm
|
||||
|
||||
from litellm.llms.base_llm.image_edit.transformation import BaseImageEditConfig
|
||||
from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import VertexLLM
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
from litellm.types.images.main import ImageEditOptionalRequestParams
|
||||
from litellm.types.router import GenericLiteLLMParams
|
||||
from litellm.types.utils import FileTypes, ImageObject, ImageResponse, OpenAIImage
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
|
||||
|
||||
class VertexAIImagenImageEditConfig(BaseImageEditConfig, VertexLLM):
|
||||
"""
|
||||
Vertex AI Imagen Image Edit Configuration
|
||||
|
||||
Uses predict API for Imagen models on Vertex AI
|
||||
"""
|
||||
SUPPORTED_PARAMS: List[str] = ["n", "size", "mask"]
|
||||
|
||||
def __init__(self) -> None:
|
||||
BaseImageEditConfig.__init__(self)
|
||||
VertexLLM.__init__(self)
|
||||
|
||||
def get_supported_openai_params(self, model: str) -> List[str]:
|
||||
return list(self.SUPPORTED_PARAMS)
|
||||
|
||||
def map_openai_params(
|
||||
self,
|
||||
image_edit_optional_params: ImageEditOptionalRequestParams,
|
||||
model: str,
|
||||
drop_params: bool,
|
||||
) -> Dict[str, Any]:
|
||||
supported_params = self.get_supported_openai_params(model)
|
||||
filtered_params = {
|
||||
key: value
|
||||
for key, value in image_edit_optional_params.items()
|
||||
if key in supported_params
|
||||
}
|
||||
|
||||
mapped_params: Dict[str, Any] = {}
|
||||
|
||||
# Map OpenAI parameters to Imagen format
|
||||
if "n" in filtered_params:
|
||||
mapped_params["sampleCount"] = filtered_params["n"]
|
||||
|
||||
if "size" in filtered_params:
|
||||
mapped_params["aspectRatio"] = self._map_size_to_aspect_ratio(
|
||||
filtered_params["size"] # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
if "mask" in filtered_params:
|
||||
mapped_params["mask"] = filtered_params["mask"]
|
||||
|
||||
return mapped_params
|
||||
|
||||
def _resolve_vertex_project(self) -> Optional[str]:
|
||||
return (
|
||||
getattr(self, "_vertex_project", None)
|
||||
or os.environ.get("VERTEXAI_PROJECT")
|
||||
or getattr(litellm, "vertex_project", None)
|
||||
or get_secret_str("VERTEXAI_PROJECT")
|
||||
)
|
||||
|
||||
def _resolve_vertex_location(self) -> Optional[str]:
|
||||
return (
|
||||
getattr(self, "_vertex_location", None)
|
||||
or os.environ.get("VERTEXAI_LOCATION")
|
||||
or os.environ.get("VERTEX_LOCATION")
|
||||
or getattr(litellm, "vertex_location", None)
|
||||
or get_secret_str("VERTEXAI_LOCATION")
|
||||
or get_secret_str("VERTEX_LOCATION")
|
||||
)
|
||||
|
||||
def _resolve_vertex_credentials(self) -> Optional[str]:
|
||||
return (
|
||||
getattr(self, "_vertex_credentials", None)
|
||||
or os.environ.get("VERTEXAI_CREDENTIALS")
|
||||
or getattr(litellm, "vertex_credentials", None)
|
||||
or os.environ.get("GOOGLE_APPLICATION_CREDENTIALS")
|
||||
or get_secret_str("VERTEXAI_CREDENTIALS")
|
||||
)
|
||||
|
||||
def validate_environment(
|
||||
self,
|
||||
headers: dict,
|
||||
model: str,
|
||||
api_key: Optional[str] = None,
|
||||
) -> dict:
|
||||
headers = headers or {}
|
||||
vertex_project = self._resolve_vertex_project()
|
||||
vertex_credentials = self._resolve_vertex_credentials()
|
||||
access_token, _ = self._ensure_access_token(
|
||||
credentials=vertex_credentials,
|
||||
project_id=vertex_project,
|
||||
custom_llm_provider="vertex_ai",
|
||||
)
|
||||
return self.set_headers(access_token, headers)
|
||||
|
||||
def get_complete_url(
|
||||
self,
|
||||
model: str,
|
||||
api_base: Optional[str],
|
||||
litellm_params: dict,
|
||||
) -> str:
|
||||
"""
|
||||
Get the complete URL for Vertex AI Imagen predict API
|
||||
"""
|
||||
vertex_project = self._resolve_vertex_project()
|
||||
vertex_location = self._resolve_vertex_location()
|
||||
|
||||
if not vertex_project or not vertex_location:
|
||||
raise ValueError("vertex_project and vertex_location are required for Vertex AI")
|
||||
|
||||
# Use the model name as provided, handling vertex_ai prefix
|
||||
model_name = model
|
||||
if model.startswith("vertex_ai/"):
|
||||
model_name = model.replace("vertex_ai/", "")
|
||||
|
||||
if api_base:
|
||||
base_url = api_base.rstrip("/")
|
||||
else:
|
||||
base_url = f"https://{vertex_location}-aiplatform.googleapis.com"
|
||||
|
||||
return f"{base_url}/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model_name}:predict"
|
||||
|
||||
def transform_image_edit_request( # type: ignore[override]
|
||||
self,
|
||||
model: str,
|
||||
prompt: str,
|
||||
image: FileTypes,
|
||||
image_edit_optional_request_params: Dict[str, Any],
|
||||
litellm_params: GenericLiteLLMParams,
|
||||
headers: dict,
|
||||
) -> Tuple[Dict[str, Any], Optional[RequestFiles]]:
|
||||
# Prepare reference images in the correct Imagen format
|
||||
reference_images = self._prepare_reference_images(image, image_edit_optional_request_params)
|
||||
if not reference_images:
|
||||
raise ValueError("Vertex AI Imagen image edit requires at least one reference image.")
|
||||
|
||||
# Correct Imagen instances format
|
||||
instances = [
|
||||
{
|
||||
"prompt": prompt,
|
||||
"referenceImages": reference_images
|
||||
}
|
||||
]
|
||||
|
||||
# Extract OpenAI parameters and set sensible defaults for Vertex AI-specific parameters
|
||||
sample_count = image_edit_optional_request_params.get("sampleCount", 1)
|
||||
# Use sensible defaults for Vertex AI-specific parameters (not exposed to users)
|
||||
edit_mode = "EDIT_MODE_INPAINT_INSERTION" # Default edit mode
|
||||
base_steps = 50 # Default number of steps
|
||||
|
||||
# Imagen parameters with correct structure
|
||||
parameters = {
|
||||
"sampleCount": sample_count,
|
||||
"editMode": edit_mode,
|
||||
"editConfig": {
|
||||
"baseSteps": base_steps
|
||||
}
|
||||
}
|
||||
|
||||
# Set default values for Vertex AI-specific parameters (not configurable by users via OpenAI API)
|
||||
parameters["guidanceScale"] = 7.5 # Default guidance scale
|
||||
parameters["seed"] = None # Let Vertex AI choose random seed
|
||||
|
||||
request_body: Dict[str, Any] = {
|
||||
"instances": instances,
|
||||
"parameters": parameters
|
||||
}
|
||||
|
||||
payload: Any = json.dumps(request_body)
|
||||
empty_files = cast(RequestFiles, [])
|
||||
return cast(Tuple[Dict[str, Any], Optional[RequestFiles]], (payload, empty_files))
|
||||
|
||||
def transform_image_edit_response(
|
||||
self,
|
||||
model: str,
|
||||
raw_response: httpx.Response,
|
||||
logging_obj: Any,
|
||||
) -> ImageResponse:
|
||||
model_response = ImageResponse()
|
||||
try:
|
||||
response_json = raw_response.json()
|
||||
except Exception as exc:
|
||||
raise self.get_error_class(
|
||||
error_message=f"Error transforming image edit response: {exc}",
|
||||
status_code=raw_response.status_code,
|
||||
headers=raw_response.headers,
|
||||
)
|
||||
|
||||
predictions = response_json.get("predictions", [])
|
||||
data_list: List[ImageObject] = []
|
||||
|
||||
for prediction in predictions:
|
||||
# Imagen returns images as bytesBase64Encoded
|
||||
if "bytesBase64Encoded" in prediction:
|
||||
data_list.append(
|
||||
ImageObject(
|
||||
b64_json=prediction["bytesBase64Encoded"],
|
||||
url=None,
|
||||
)
|
||||
)
|
||||
|
||||
model_response.data = cast(List[OpenAIImage], data_list)
|
||||
return model_response
|
||||
|
||||
def _map_size_to_aspect_ratio(self, size: str) -> str:
|
||||
"""Map OpenAI size format to Imagen aspect ratio format"""
|
||||
aspect_ratio_map = {
|
||||
"1024x1024": "1:1",
|
||||
"1792x1024": "16:9",
|
||||
"1024x1792": "9:16",
|
||||
"1280x896": "4:3",
|
||||
"896x1280": "3:4",
|
||||
}
|
||||
return aspect_ratio_map.get(size, "1:1")
|
||||
|
||||
def _prepare_reference_images(
|
||||
self, image: Union[FileTypes, List[FileTypes]],
|
||||
image_edit_optional_request_params: Dict[str, Any]
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Prepare reference images in the correct Imagen API format
|
||||
"""
|
||||
images: List[FileTypes]
|
||||
if isinstance(image, list):
|
||||
images = image
|
||||
else:
|
||||
images = [image]
|
||||
|
||||
reference_images: List[Dict[str, Any]] = []
|
||||
|
||||
for idx, img in enumerate(images):
|
||||
if img is None:
|
||||
continue
|
||||
|
||||
image_bytes = self._read_all_bytes(img)
|
||||
base64_data = base64.b64encode(image_bytes).decode("utf-8")
|
||||
|
||||
# Create reference image structure
|
||||
reference_image = {
|
||||
"referenceType": "REFERENCE_TYPE_RAW",
|
||||
"referenceId": idx + 1,
|
||||
"referenceImage": {
|
||||
"bytesBase64Encoded": base64_data
|
||||
}
|
||||
}
|
||||
|
||||
reference_images.append(reference_image)
|
||||
|
||||
# Handle mask image if provided (for inpainting)
|
||||
mask_image = image_edit_optional_request_params.get("mask")
|
||||
if mask_image is not None:
|
||||
mask_bytes = self._read_all_bytes(mask_image)
|
||||
mask_base64 = base64.b64encode(mask_bytes).decode("utf-8")
|
||||
|
||||
mask_reference = {
|
||||
"referenceType": "REFERENCE_TYPE_MASK",
|
||||
"referenceId": len(reference_images) + 1,
|
||||
"referenceImage": {
|
||||
"bytesBase64Encoded": mask_base64
|
||||
},
|
||||
"maskImageConfig": {
|
||||
"maskMode": "MASK_MODE_USER_PROVIDED",
|
||||
"dilation": 0.03 # Default dilation value (not configurable via OpenAI API)
|
||||
}
|
||||
}
|
||||
reference_images.append(mask_reference)
|
||||
|
||||
return reference_images
|
||||
|
||||
def _read_all_bytes(self, image: Any) -> bytes:
|
||||
if isinstance(image, (list, tuple)):
|
||||
for item in image:
|
||||
if item is not None:
|
||||
return self._read_all_bytes(item)
|
||||
raise ValueError("Unsupported image type for Vertex AI Imagen image edit.")
|
||||
|
||||
if isinstance(image, dict):
|
||||
for key in ("data", "bytes", "content"):
|
||||
if key in image and image[key] is not None:
|
||||
value = image[key]
|
||||
if isinstance(value, str):
|
||||
try:
|
||||
return base64.b64decode(value)
|
||||
except Exception:
|
||||
continue
|
||||
return self._read_all_bytes(value)
|
||||
if "path" in image:
|
||||
return self._read_all_bytes(image["path"])
|
||||
|
||||
if isinstance(image, bytes):
|
||||
return image
|
||||
if isinstance(image, bytearray):
|
||||
return bytes(image)
|
||||
if isinstance(image, BytesIO):
|
||||
current_pos = image.tell()
|
||||
image.seek(0)
|
||||
data = image.read()
|
||||
image.seek(current_pos)
|
||||
return data
|
||||
if isinstance(image, (BufferedReader, BufferedRandom)):
|
||||
stream_pos: Optional[int] = None
|
||||
try:
|
||||
stream_pos = image.tell()
|
||||
except Exception:
|
||||
stream_pos = None
|
||||
if stream_pos is not None:
|
||||
image.seek(0)
|
||||
data = image.read()
|
||||
if stream_pos is not None:
|
||||
image.seek(stream_pos)
|
||||
return data
|
||||
if isinstance(image, (str, Path)):
|
||||
path_obj = Path(image)
|
||||
if not path_obj.exists():
|
||||
raise ValueError(
|
||||
f"Mask/image path does not exist for Vertex AI Imagen image edit: {path_obj}"
|
||||
)
|
||||
return path_obj.read_bytes()
|
||||
if hasattr(image, "read"):
|
||||
data = image.read()
|
||||
if isinstance(data, str):
|
||||
data = data.encode("utf-8")
|
||||
return data
|
||||
raise ValueError(
|
||||
f"Unsupported image type for Vertex AI Imagen image edit. Got type={type(image)}"
|
||||
)
|
||||
@@ -24434,6 +24434,12 @@
|
||||
"supports_reasoning": true,
|
||||
"supports_tool_choice": true
|
||||
},
|
||||
"vertex_ai/gemini-2.5-flash-image": {
|
||||
"litellm_provider": "vertex_ai-language-models",
|
||||
"mode": "image_generation",
|
||||
"output_cost_per_image": 0.039,
|
||||
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/image-generation#edit-an-image"
|
||||
},
|
||||
"vertex_ai/imagegeneration@006": {
|
||||
"litellm_provider": "vertex_ai-image-models",
|
||||
"mode": "image_generation",
|
||||
@@ -24458,6 +24464,12 @@
|
||||
"output_cost_per_image": 0.04,
|
||||
"source": "https://cloud.google.com/vertex-ai/generative-ai/pricing"
|
||||
},
|
||||
"vertex_ai/imagen-3.0-capability-001": {
|
||||
"litellm_provider": "vertex_ai-image-models",
|
||||
"mode": "image_generation",
|
||||
"output_cost_per_image": 0.04,
|
||||
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/image/edit-insert-objects"
|
||||
},
|
||||
"vertex_ai/imagen-4.0-fast-generate-001": {
|
||||
"litellm_provider": "vertex_ai-image-models",
|
||||
"mode": "image_generation",
|
||||
|
||||
@@ -7757,6 +7757,10 @@ class ProviderConfigManager:
|
||||
)
|
||||
|
||||
return LiteLLMProxyImageEditConfig()
|
||||
elif LlmProviders.VERTEX_AI == provider:
|
||||
from litellm.llms.vertex_ai.image_edit import get_vertex_ai_image_edit_config
|
||||
|
||||
return get_vertex_ai_image_edit_config(model)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -24434,6 +24434,12 @@
|
||||
"supports_reasoning": true,
|
||||
"supports_tool_choice": true
|
||||
},
|
||||
"vertex_ai/gemini-2.5-flash-image": {
|
||||
"litellm_provider": "vertex_ai-language-models",
|
||||
"mode": "image_generation",
|
||||
"output_cost_per_image": 0.039,
|
||||
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/image-generation#edit-an-image"
|
||||
},
|
||||
"vertex_ai/imagegeneration@006": {
|
||||
"litellm_provider": "vertex_ai-image-models",
|
||||
"mode": "image_generation",
|
||||
@@ -24458,6 +24464,12 @@
|
||||
"output_cost_per_image": 0.04,
|
||||
"source": "https://cloud.google.com/vertex-ai/generative-ai/pricing"
|
||||
},
|
||||
"vertex_ai/imagen-3.0-capability-001": {
|
||||
"litellm_provider": "vertex_ai-image-models",
|
||||
"mode": "image_generation",
|
||||
"output_cost_per_image": 0.04,
|
||||
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/image/edit-insert-objects"
|
||||
},
|
||||
"vertex_ai/imagen-4.0-fast-generate-001": {
|
||||
"litellm_provider": "vertex_ai-image-models",
|
||||
"mode": "image_generation",
|
||||
|
||||
@@ -0,0 +1,2 @@
|
||||
# Vertex AI Image Edit Tests
|
||||
|
||||
+321
@@ -0,0 +1,321 @@
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
from io import BytesIO
|
||||
from typing import Dict
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from litellm.llms.vertex_ai.image_edit.vertex_gemini_transformation import (
|
||||
VertexAIGeminiImageEditConfig,
|
||||
)
|
||||
from litellm.llms.vertex_ai.image_edit.vertex_imagen_transformation import (
|
||||
VertexAIImagenImageEditConfig,
|
||||
)
|
||||
|
||||
|
||||
class TestVertexAIGeminiImageEditTransformation:
|
||||
def setup_method(self) -> None:
|
||||
self.config = VertexAIGeminiImageEditConfig()
|
||||
self.model = "vertex_ai/gemini-2.5-flash"
|
||||
self.prompt = "Add neon lights in the background"
|
||||
self.logging_obj = MagicMock()
|
||||
|
||||
def test_map_openai_params(self) -> None:
|
||||
"""Test mapping OpenAI parameters to Vertex AI Gemini format"""
|
||||
optional_params: Dict[str, object] = {
|
||||
"size": "1792x1024",
|
||||
}
|
||||
|
||||
mapped = self.config.map_openai_params(
|
||||
image_edit_optional_params=optional_params, # type: ignore[arg-type]
|
||||
model=self.model,
|
||||
drop_params=False,
|
||||
)
|
||||
|
||||
assert mapped["aspectRatio"] == "16:9"
|
||||
|
||||
def test_get_complete_url(self) -> None:
|
||||
"""Test URL generation for Vertex AI Gemini"""
|
||||
with patch.dict(
|
||||
os.environ,
|
||||
{
|
||||
"VERTEXAI_PROJECT": "test-project",
|
||||
"VERTEXAI_LOCATION": "us-central1",
|
||||
},
|
||||
):
|
||||
url = self.config.get_complete_url(
|
||||
model="gemini-2.5-flash",
|
||||
api_base=None,
|
||||
litellm_params={},
|
||||
)
|
||||
assert "test-project" in url
|
||||
assert "us-central1" in url
|
||||
assert "generateContent" in url
|
||||
|
||||
def test_transform_image_edit_request(self) -> None:
|
||||
"""Test request transformation for Vertex AI Gemini"""
|
||||
image_bytes = b"fake_image_data"
|
||||
image = BytesIO(image_bytes)
|
||||
optional_params = {
|
||||
"aspectRatio": "1:1",
|
||||
}
|
||||
|
||||
request_body_str, files = self.config.transform_image_edit_request(
|
||||
model=self.model,
|
||||
prompt=self.prompt,
|
||||
image=image,
|
||||
image_edit_optional_request_params=optional_params,
|
||||
litellm_params=MagicMock(),
|
||||
headers={},
|
||||
)
|
||||
|
||||
assert files == []
|
||||
assert isinstance(request_body_str, str)
|
||||
|
||||
request_body = json.loads(request_body_str)
|
||||
assert "contents" in request_body
|
||||
assert request_body["contents"]["role"] == "USER"
|
||||
|
||||
parts = request_body["contents"]["parts"]
|
||||
assert parts[-1]["text"] == self.prompt
|
||||
|
||||
inline_data = parts[0]["inlineData"]
|
||||
assert inline_data["mimeType"] == "image/png"
|
||||
assert base64.b64decode(inline_data["data"]) == image_bytes
|
||||
|
||||
generation_config = request_body["generationConfig"]
|
||||
assert generation_config["response_modalities"] == ["IMAGE"]
|
||||
assert generation_config["image_config"]["aspect_ratio"] == "1:1"
|
||||
|
||||
def test_transform_image_edit_response(self) -> None:
|
||||
"""Test response transformation for Vertex AI Gemini"""
|
||||
response_payload = {
|
||||
"candidates": [
|
||||
{
|
||||
"content": {
|
||||
"parts": [
|
||||
{
|
||||
"inlineData": {
|
||||
"mimeType": "image/png",
|
||||
"data": base64.b64encode(b"image-one").decode("utf-8"),
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
mock_response = MagicMock(spec=httpx.Response)
|
||||
mock_response.json.return_value = response_payload
|
||||
mock_response.status_code = 200
|
||||
mock_response.headers = {}
|
||||
|
||||
image_response = self.config.transform_image_edit_response(
|
||||
model=self.model,
|
||||
raw_response=mock_response,
|
||||
logging_obj=self.logging_obj,
|
||||
)
|
||||
|
||||
assert image_response.data is not None
|
||||
assert len(image_response.data) == 1
|
||||
assert image_response.data[0].b64_json == base64.b64encode(b"image-one").decode(
|
||||
"utf-8"
|
||||
)
|
||||
|
||||
def test_transform_image_edit_request_without_image_raises(self) -> None:
|
||||
"""Test that missing image raises ValueError"""
|
||||
optional_params = {}
|
||||
|
||||
with pytest.raises(ValueError, match="requires at least one image"):
|
||||
self.config.transform_image_edit_request(
|
||||
model=self.model,
|
||||
prompt=self.prompt,
|
||||
image=[],
|
||||
image_edit_optional_request_params=optional_params,
|
||||
litellm_params=MagicMock(),
|
||||
headers={},
|
||||
)
|
||||
|
||||
|
||||
class TestVertexAIImagenImageEditTransformation:
|
||||
def setup_method(self) -> None:
|
||||
self.config = VertexAIImagenImageEditConfig()
|
||||
self.model = "vertex_ai/imagen-3.0-capability-001"
|
||||
self.prompt = "Turn this into watercolor style scenery"
|
||||
self.logging_obj = MagicMock()
|
||||
|
||||
def test_map_openai_params(self) -> None:
|
||||
"""Test mapping OpenAI parameters to Vertex AI Imagen format"""
|
||||
optional_params: Dict[str, object] = {
|
||||
"n": 2,
|
||||
"size": "1024x1024",
|
||||
"mask": BytesIO(b"mask_data"),
|
||||
}
|
||||
|
||||
mapped = self.config.map_openai_params(
|
||||
image_edit_optional_params=optional_params, # type: ignore[arg-type]
|
||||
model=self.model,
|
||||
drop_params=False,
|
||||
)
|
||||
|
||||
assert mapped["sampleCount"] == 2
|
||||
assert mapped["aspectRatio"] == "1:1"
|
||||
assert "mask" in mapped
|
||||
|
||||
def test_get_complete_url(self) -> None:
|
||||
"""Test URL generation for Vertex AI Imagen"""
|
||||
with patch.dict(
|
||||
os.environ,
|
||||
{
|
||||
"VERTEXAI_PROJECT": "test-project",
|
||||
"VERTEXAI_LOCATION": "us-central1",
|
||||
},
|
||||
):
|
||||
url = self.config.get_complete_url(
|
||||
model="imagen-3.0-capability-001",
|
||||
api_base=None,
|
||||
litellm_params={},
|
||||
)
|
||||
assert "test-project" in url
|
||||
assert "us-central1" in url
|
||||
assert "predict" in url
|
||||
|
||||
def test_transform_image_edit_request(self) -> None:
|
||||
"""Test request transformation for Vertex AI Imagen"""
|
||||
image_bytes = b"fake_image_data"
|
||||
image = BytesIO(image_bytes)
|
||||
optional_params = {
|
||||
"sampleCount": 1,
|
||||
}
|
||||
|
||||
request_body_str, files = self.config.transform_image_edit_request(
|
||||
model=self.model,
|
||||
prompt=self.prompt,
|
||||
image=image,
|
||||
image_edit_optional_request_params=optional_params,
|
||||
litellm_params=MagicMock(),
|
||||
headers={},
|
||||
)
|
||||
|
||||
assert files == []
|
||||
assert isinstance(request_body_str, str)
|
||||
|
||||
request_body = json.loads(request_body_str)
|
||||
assert "instances" in request_body
|
||||
assert "parameters" in request_body
|
||||
|
||||
instance = request_body["instances"][0]
|
||||
assert instance["prompt"] == self.prompt
|
||||
assert "referenceImages" in instance
|
||||
|
||||
reference_image = instance["referenceImages"][0]
|
||||
assert reference_image["referenceType"] == "REFERENCE_TYPE_RAW"
|
||||
assert reference_image["referenceId"] == 1
|
||||
assert "referenceImage" in reference_image
|
||||
assert "bytesBase64Encoded" in reference_image["referenceImage"]
|
||||
|
||||
parameters = request_body["parameters"]
|
||||
assert parameters["sampleCount"] == 1
|
||||
assert parameters["editMode"] == "EDIT_MODE_INPAINT_INSERTION"
|
||||
assert "editConfig" in parameters
|
||||
|
||||
def test_transform_image_edit_request_with_mask(self) -> None:
|
||||
"""Test request transformation with mask for inpainting"""
|
||||
image_bytes = b"fake_image_data"
|
||||
mask_bytes = b"mask_data"
|
||||
image = BytesIO(image_bytes)
|
||||
mask = BytesIO(mask_bytes)
|
||||
optional_params = {
|
||||
"sampleCount": 2,
|
||||
"mask": mask,
|
||||
}
|
||||
|
||||
request_body_str, files = self.config.transform_image_edit_request(
|
||||
model=self.model,
|
||||
prompt=self.prompt,
|
||||
image=image,
|
||||
image_edit_optional_request_params=optional_params,
|
||||
litellm_params=MagicMock(),
|
||||
headers={},
|
||||
)
|
||||
|
||||
request_body = json.loads(request_body_str)
|
||||
reference_images = request_body["instances"][0]["referenceImages"]
|
||||
|
||||
# Should have both base image and mask
|
||||
assert len(reference_images) == 2
|
||||
|
||||
# First should be RAW reference
|
||||
assert reference_images[0]["referenceType"] == "REFERENCE_TYPE_RAW"
|
||||
assert reference_images[0]["referenceId"] == 1
|
||||
|
||||
# Second should be MASK reference
|
||||
assert reference_images[1]["referenceType"] == "REFERENCE_TYPE_MASK"
|
||||
assert "maskImageConfig" in reference_images[1]
|
||||
assert reference_images[1]["maskImageConfig"]["maskMode"] == "MASK_MODE_USER_PROVIDED"
|
||||
|
||||
def test_transform_image_edit_response(self) -> None:
|
||||
"""Test response transformation for Vertex AI Imagen"""
|
||||
response_payload = {
|
||||
"predictions": [
|
||||
{
|
||||
"bytesBase64Encoded": base64.b64encode(b"image-one").decode("utf-8"),
|
||||
"mimeType": "image/png",
|
||||
},
|
||||
{
|
||||
"bytesBase64Encoded": base64.b64encode(b"image-two").decode("utf-8"),
|
||||
"mimeType": "image/png",
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
mock_response = MagicMock(spec=httpx.Response)
|
||||
mock_response.json.return_value = response_payload
|
||||
mock_response.status_code = 200
|
||||
mock_response.headers = {}
|
||||
|
||||
image_response = self.config.transform_image_edit_response(
|
||||
model=self.model,
|
||||
raw_response=mock_response,
|
||||
logging_obj=self.logging_obj,
|
||||
)
|
||||
|
||||
assert image_response.data is not None
|
||||
assert len(image_response.data) == 2
|
||||
assert image_response.data[0].b64_json == base64.b64encode(b"image-one").decode(
|
||||
"utf-8"
|
||||
)
|
||||
assert image_response.data[1].b64_json == base64.b64encode(b"image-two").decode(
|
||||
"utf-8"
|
||||
)
|
||||
|
||||
def test_transform_image_edit_request_without_image_raises(self) -> None:
|
||||
"""Test that missing image raises ValueError"""
|
||||
optional_params = {}
|
||||
|
||||
with pytest.raises(ValueError, match="requires at least one reference image"):
|
||||
self.config.transform_image_edit_request(
|
||||
model=self.model,
|
||||
prompt=self.prompt,
|
||||
image=[],
|
||||
image_edit_optional_request_params=optional_params,
|
||||
litellm_params=MagicMock(),
|
||||
headers={},
|
||||
)
|
||||
|
||||
def test_read_all_bytes_handles_various_types(self) -> None:
|
||||
"""Test that _read_all_bytes handles different file types"""
|
||||
# Test with bytes
|
||||
assert self.config._read_all_bytes(b"test_bytes") == b"test_bytes"
|
||||
|
||||
# Test with BytesIO
|
||||
bio = BytesIO(b"test_bytesio")
|
||||
assert self.config._read_all_bytes(bio) == b"test_bytesio"
|
||||
|
||||
# Test with bytearray
|
||||
assert self.config._read_all_bytes(bytearray(b"test_bytearray")) == b"test_bytearray"
|
||||
|
||||
Reference in New Issue
Block a user