From 883cfaeeafa2ecb16f91a540be63dfbfa644001c Mon Sep 17 00:00:00 2001 From: Sameer Kankute Date: Tue, 25 Nov 2025 13:32:13 +0530 Subject: [PATCH] Add tests --- .../image_gen_tests/test_image_generation.py | 32 ++ ...rtex_ai_image_generation_transformation.py | 457 ++++++++++++++++++ 2 files changed, 489 insertions(+) create mode 100644 tests/test_litellm/llms/vertex_ai/image_generation/test_vertex_ai_image_generation_transformation.py diff --git a/tests/image_gen_tests/test_image_generation.py b/tests/image_gen_tests/test_image_generation.py index 24d5843293..add60c755b 100644 --- a/tests/image_gen_tests/test_image_generation.py +++ b/tests/image_gen_tests/test_image_generation.py @@ -119,6 +119,38 @@ class TestVertexImageGeneration(BaseImageGenTest): } +class TestVertexAIGeminiImageGeneration(BaseImageGenTest): + """Test Gemini image generation models (Nano Banana)""" + def get_base_image_generation_call_args(self) -> dict: + # comment this when running locally + load_vertex_ai_credentials() + + litellm.in_memory_llm_clients_cache = InMemoryCache() + return { + "model": "vertex_ai/gemini-2.5-flash-image", + "vertex_ai_project": "pathrise-convert-1606954137718", + "vertex_ai_location": "us-central1", + "n": 1, + "size": "1024x1024", + } + + +class TestVertexAIGemini3ProImageGeneration(BaseImageGenTest): + """Test Gemini 3 Pro image generation model""" + def get_base_image_generation_call_args(self) -> dict: + # comment this when running locally + load_vertex_ai_credentials() + + litellm.in_memory_llm_clients_cache = InMemoryCache() + return { + "model": "vertex_ai/gemini-3-pro-image-preview", + "vertex_ai_project": "pathrise-convert-1606954137718", + "vertex_ai_location": "us-central1", + "n": 1, + "size": "1024x1024", + } + + class TestBedrockNovaCanvasTextToImage(BaseImageGenTest): def get_base_image_generation_call_args(self) -> dict: litellm.in_memory_llm_clients_cache = InMemoryCache() diff --git a/tests/test_litellm/llms/vertex_ai/image_generation/test_vertex_ai_image_generation_transformation.py b/tests/test_litellm/llms/vertex_ai/image_generation/test_vertex_ai_image_generation_transformation.py new file mode 100644 index 0000000000..9f33400594 --- /dev/null +++ b/tests/test_litellm/llms/vertex_ai/image_generation/test_vertex_ai_image_generation_transformation.py @@ -0,0 +1,457 @@ +import os +import sys +from unittest.mock import MagicMock, patch + +import httpx +import pytest + +sys.path.insert(0, os.path.abspath("../..")) + +from litellm.llms.vertex_ai.image_generation import ( + get_vertex_ai_image_generation_config, +) +from litellm.llms.vertex_ai.image_generation.vertex_gemini_transformation import ( + VertexAIGeminiImageGenerationConfig, +) +from litellm.llms.vertex_ai.image_generation.vertex_imagen_transformation import ( + VertexAIImagenImageGenerationConfig, +) + + +class TestVertexAIGeminiImageGenerationConfig: + def setup_method(self): + """Set up test fixtures""" + self.config = VertexAIGeminiImageGenerationConfig() + + def test_get_supported_openai_params(self): + """Test get_supported_openai_params returns correct params""" + supported = self.config.get_supported_openai_params("gemini-2.5-flash-image") + assert "n" in supported + assert "size" in supported + + def test_map_openai_params_n(self): + """Test mapping n parameter to candidate_count""" + non_default_params = {"n": 3} + optional_params = {} + result = self.config.map_openai_params( + non_default_params, optional_params, "gemini-2.5-flash-image", False + ) + assert result.get("candidate_count") == 3 + + def test_map_openai_params_size(self): + """Test mapping size parameter to aspectRatio""" + non_default_params = {"size": "1024x1024"} + optional_params = {} + result = self.config.map_openai_params( + non_default_params, optional_params, "gemini-2.5-flash-image", False + ) + assert result.get("aspectRatio") == "1:1" + + def test_map_openai_params_size_16_9(self): + """Test mapping 16:9 size""" + non_default_params = {"size": "1792x1024"} + optional_params = {} + result = self.config.map_openai_params( + non_default_params, optional_params, "gemini-2.5-flash-image", False + ) + assert result.get("aspectRatio") == "16:9" + + def test_map_size_to_aspect_ratio(self): + """Test size to aspect ratio mapping""" + assert self.config._map_size_to_aspect_ratio("1024x1024") == "1:1" + assert self.config._map_size_to_aspect_ratio("1792x1024") == "16:9" + assert self.config._map_size_to_aspect_ratio("1024x1792") == "9:16" + assert self.config._map_size_to_aspect_ratio("1280x896") == "4:3" + assert self.config._map_size_to_aspect_ratio("896x1280") == "3:4" + assert self.config._map_size_to_aspect_ratio("unknown") == "1:1" # default + + def test_transform_image_generation_request_basic(self): + """Test basic request transformation""" + request = self.config.transform_image_generation_request( + model="gemini-2.5-flash-image", + prompt="A nano banana", + optional_params={}, + litellm_params={}, + headers={}, + ) + assert "contents" in request + assert "generationConfig" in request + assert request["generationConfig"]["responseModalities"] == ["IMAGE"] + assert request["contents"][0]["parts"][0]["text"] == "A nano banana" + + def test_transform_image_generation_request_with_aspect_ratio(self): + """Test request transformation with aspectRatio""" + request = self.config.transform_image_generation_request( + model="gemini-2.5-flash-image", + prompt="A nano banana", + optional_params={"aspectRatio": "16:9"}, + litellm_params={}, + headers={}, + ) + assert request["generationConfig"]["imageConfig"]["aspectRatio"] == "16:9" + + def test_transform_image_generation_request_with_image_size(self): + """Test request transformation with imageSize (Gemini 3 Pro)""" + request = self.config.transform_image_generation_request( + model="gemini-3-pro-image-preview", + prompt="A nano banana", + optional_params={"imageSize": "4K"}, + litellm_params={}, + headers={}, + ) + assert request["generationConfig"]["imageConfig"]["imageSize"] == "4K" + + def test_transform_image_generation_request_with_candidate_count(self): + """Test request transformation with candidate_count""" + request = self.config.transform_image_generation_request( + model="gemini-2.5-flash-image", + prompt="A nano banana", + optional_params={"candidate_count": 2}, + litellm_params={}, + headers={}, + ) + assert request["generationConfig"]["candidateCount"] == 2 + + def test_transform_image_generation_request_with_n(self): + """Test request transformation with n parameter""" + request = self.config.transform_image_generation_request( + model="gemini-2.5-flash-image", + prompt="A nano banana", + optional_params={"n": 2}, + litellm_params={}, + headers={}, + ) + assert request["generationConfig"]["candidateCount"] == 2 + + def test_transform_image_generation_response(self): + """Test response transformation""" + mock_response = MagicMock(spec=httpx.Response) + mock_response.status_code = 200 + mock_response.json.return_value = { + "candidates": [ + { + "content": { + "parts": [ + { + "inlineData": { + "mimeType": "image/png", + "data": "base64_encoded_image_data", + } + } + ] + } + } + ] + } + mock_response.headers = {} + + from litellm.types.utils import ImageResponse + + model_response = ImageResponse() + result = self.config.transform_image_generation_response( + model="gemini-2.5-flash-image", + raw_response=mock_response, + model_response=model_response, + logging_obj=MagicMock(), + request_data={}, + optional_params={}, + litellm_params={}, + encoding=None, + ) + + assert len(result.data) == 1 + assert result.data[0].b64_json == "base64_encoded_image_data" + assert result.data[0].url is None + + def test_transform_image_generation_response_multiple_images(self): + """Test response transformation with multiple images""" + mock_response = MagicMock(spec=httpx.Response) + mock_response.status_code = 200 + mock_response.json.return_value = { + "candidates": [ + { + "content": { + "parts": [ + { + "inlineData": { + "mimeType": "image/png", + "data": "image1", + } + }, + { + "inlineData": { + "mimeType": "image/png", + "data": "image2", + } + }, + ] + } + } + ] + } + mock_response.headers = {} + + from litellm.types.utils import ImageResponse + + model_response = ImageResponse() + result = self.config.transform_image_generation_response( + model="gemini-2.5-flash-image", + raw_response=mock_response, + model_response=model_response, + logging_obj=MagicMock(), + request_data={}, + optional_params={}, + litellm_params={}, + encoding=None, + ) + + assert len(result.data) == 2 + assert result.data[0].b64_json == "image1" + assert result.data[1].b64_json == "image2" + + +class TestVertexAIImagenImageGenerationConfig: + def setup_method(self): + """Set up test fixtures""" + self.config = VertexAIImagenImageGenerationConfig() + + def test_get_supported_openai_params(self): + """Test get_supported_openai_params returns correct params""" + supported = self.config.get_supported_openai_params("imagegeneration@006") + assert "n" in supported + assert "size" in supported + + def test_map_openai_params_n(self): + """Test mapping n parameter to sampleCount""" + non_default_params = {"n": 3} + optional_params = {} + result = self.config.map_openai_params( + non_default_params, optional_params, "imagegeneration@006", False + ) + assert result.get("sampleCount") == 3 + + def test_map_openai_params_size(self): + """Test mapping size parameter to aspectRatio""" + non_default_params = {"size": "1024x1024"} + optional_params = {} + result = self.config.map_openai_params( + non_default_params, optional_params, "imagegeneration@006", False + ) + assert result.get("aspectRatio") == "1:1" + + def test_map_size_to_aspect_ratio(self): + """Test size to aspect ratio mapping""" + assert self.config._map_size_to_aspect_ratio("1024x1024") == "1:1" + assert self.config._map_size_to_aspect_ratio("1792x1024") == "16:9" + assert self.config._map_size_to_aspect_ratio("unknown") == "1:1" # default + + def test_transform_image_generation_request_basic(self): + """Test basic request transformation""" + request = self.config.transform_image_generation_request( + model="imagegeneration@006", + prompt="A cat", + optional_params={}, + litellm_params={}, + headers={}, + ) + assert "instances" in request + assert "parameters" in request + assert request["instances"][0]["prompt"] == "A cat" + assert request["parameters"]["sampleCount"] == 1 + + def test_transform_image_generation_request_with_params(self): + """Test request transformation with parameters""" + request = self.config.transform_image_generation_request( + model="imagegeneration@006", + prompt="A cat", + optional_params={"sampleCount": 2, "aspectRatio": "16:9"}, + litellm_params={}, + headers={}, + ) + assert request["parameters"]["sampleCount"] == 2 + assert request["parameters"]["aspectRatio"] == "16:9" + + def test_transform_image_generation_response(self): + """Test response transformation""" + mock_response = MagicMock(spec=httpx.Response) + mock_response.status_code = 200 + mock_response.json.return_value = { + "predictions": [ + {"bytesBase64Encoded": "base64_encoded_image_data"} + ] + } + mock_response.headers = {} + + from litellm.types.utils import ImageResponse + + model_response = ImageResponse() + result = self.config.transform_image_generation_response( + model="imagegeneration@006", + raw_response=mock_response, + model_response=model_response, + logging_obj=MagicMock(), + request_data={}, + optional_params={}, + litellm_params={}, + encoding=None, + ) + + assert len(result.data) == 1 + assert result.data[0].b64_json == "base64_encoded_image_data" + assert result.data[0].url is None + + def test_transform_image_generation_response_multiple_images(self): + """Test response transformation with multiple images""" + mock_response = MagicMock(spec=httpx.Response) + mock_response.status_code = 200 + mock_response.json.return_value = { + "predictions": [ + {"bytesBase64Encoded": "image1"}, + {"bytesBase64Encoded": "image2"}, + ] + } + mock_response.headers = {} + + from litellm.types.utils import ImageResponse + + model_response = ImageResponse() + result = self.config.transform_image_generation_response( + model="imagegeneration@006", + raw_response=mock_response, + model_response=model_response, + logging_obj=MagicMock(), + request_data={}, + optional_params={}, + litellm_params={}, + encoding=None, + ) + + assert len(result.data) == 2 + assert result.data[0].b64_json == "image1" + assert result.data[1].b64_json == "image2" + + +class TestGetVertexAIImageGenerationConfig: + """Test the router function that selects the correct config""" + + def test_get_gemini_model_config(self): + """Test that Gemini models return Gemini config""" + config = get_vertex_ai_image_generation_config("gemini-2.5-flash-image") + assert isinstance(config, VertexAIGeminiImageGenerationConfig) + + config = get_vertex_ai_image_generation_config("gemini-3-pro-image-preview") + assert isinstance(config, VertexAIGeminiImageGenerationConfig) + + config = get_vertex_ai_image_generation_config( + "vertex_ai/gemini-2.5-flash-image" + ) + assert isinstance(config, VertexAIGeminiImageGenerationConfig) + + def test_get_imagen_model_config(self): + """Test that Imagen models return Imagen config""" + config = get_vertex_ai_image_generation_config("imagegeneration@006") + assert isinstance(config, VertexAIImagenImageGenerationConfig) + + config = get_vertex_ai_image_generation_config("imagen-4.0-generate-001") + assert isinstance(config, VertexAIImagenImageGenerationConfig) + + config = get_vertex_ai_image_generation_config( + "vertex_ai/imagegeneration@006" + ) + assert isinstance(config, VertexAIImagenImageGenerationConfig) + + def test_get_non_gemini_model_config(self): + """Test that non-Gemini models default to Imagen config""" + config = get_vertex_ai_image_generation_config("some-other-model") + assert isinstance(config, VertexAIImagenImageGenerationConfig) + + +class TestVertexAIImageGenerationIntegration: + """Integration tests for Vertex AI image generation""" + + @pytest.mark.skipif( + not os.getenv("VERTEXAI_PROJECT"), + reason="Vertex AI credentials not set", + ) + def test_gemini_image_generation_config_validation(self): + """Test that Gemini config can validate environment""" + config = VertexAIGeminiImageGenerationConfig() + with patch.object( + config, "_resolve_vertex_project", return_value="test-project" + ), patch.object( + config, "_resolve_vertex_location", return_value="us-central1" + ), patch.object( + config, "_ensure_access_token", return_value=("token", None) + ): + headers = config.validate_environment( + headers={}, + model="gemini-2.5-flash-image", + messages=[], + optional_params={}, + litellm_params={}, + ) + assert "Authorization" in headers + + @pytest.mark.skipif( + not os.getenv("VERTEXAI_PROJECT"), + reason="Vertex AI credentials not set", + ) + def test_imagen_image_generation_config_validation(self): + """Test that Imagen config can validate environment""" + config = VertexAIImagenImageGenerationConfig() + with patch.object( + config, "_resolve_vertex_project", return_value="test-project" + ), patch.object( + config, "_resolve_vertex_location", return_value="us-central1" + ), patch.object( + config, "_ensure_access_token", return_value=("token", None) + ): + headers = config.validate_environment( + headers={}, + model="imagegeneration@006", + messages=[], + optional_params={}, + litellm_params={}, + ) + assert "Authorization" in headers + + def test_gemini_get_complete_url(self): + """Test Gemini config URL generation""" + config = VertexAIGeminiImageGenerationConfig() + with patch.object( + config, "_resolve_vertex_project", return_value="test-project" + ), patch.object( + config, "_resolve_vertex_location", return_value="us-central1" + ): + url = config.get_complete_url( + api_base=None, + api_key=None, + model="gemini-2.5-flash-image", + optional_params={}, + litellm_params={}, + ) + assert "test-project" in url + assert "us-central1" in url + assert "gemini-2.5-flash-image" in url + assert "generateContent" in url + + def test_imagen_get_complete_url(self): + """Test Imagen config URL generation""" + config = VertexAIImagenImageGenerationConfig() + with patch.object( + config, "_resolve_vertex_project", return_value="test-project" + ), patch.object( + config, "_resolve_vertex_location", return_value="us-central1" + ): + url = config.get_complete_url( + api_base=None, + api_key=None, + model="imagegeneration@006", + optional_params={}, + litellm_params={}, + ) + assert "test-project" in url + assert "us-central1" in url + assert "imagegeneration@006" in url + assert "predict" in url +