diff --git a/litellm/llms/prompt_templates/factory.py b/litellm/llms/prompt_templates/factory.py index 57d30a404c..8e9504ee93 100644 --- a/litellm/llms/prompt_templates/factory.py +++ b/litellm/llms/prompt_templates/factory.py @@ -385,6 +385,81 @@ def anthropic_pt( prompt += f"{AnthropicConstants.AI_PROMPT.value}" return prompt + +def _load_image_from_url(image_url: str): + """ + Loads an image from a URL. + + Args: + image_url (str): The URL of the image. + + Returns: + Image: The loaded image. + """ + from io import BytesIO + try: + from PIL import Image + except: + raise Exception("gemini image conversion failed please run `pip install Pillow`") + + # Download the image from the URL + response = requests.get(image_url) + image = Image.open(BytesIO(response.content)) + + return image + + +def _gemini_vision_convert_messages(messages: list): + """ + Converts given messages for GPT-4 Vision to Gemini format. + + Args: + messages (list): The messages to convert. Each message can be a dictionary with a "content" key. The content can be a string or a list of elements. If it is a string, it will be concatenated to the prompt. If it is a list, each element will be processed based on its type: + - If the element is a dictionary with a "type" key equal to "text", its "text" value will be concatenated to the prompt. + - If the element is a dictionary with a "type" key equal to "image_url", its "image_url" value will be added to the list of images. + + Returns: + tuple: A tuple containing the prompt (a string) and the processed images (a list of objects representing the images). + """ + try: + from PIL import Image + except: + raise Exception("gemini image conversion failed please run `pip install Pillow`") + + try: + + # given messages for gpt-4 vision, convert them for gemini + # https://github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/getting-started/intro_gemini_python.ipynb + prompt = "" + images = [] + for message in messages: + if isinstance(message["content"], str): + prompt += message["content"] + elif isinstance(message["content"], list): + # see https://docs.litellm.ai/docs/providers/openai#openai-vision-models + for element in message["content"]: + if isinstance(element, dict): + if element["type"] == "text": + prompt += element["text"] + elif element["type"] == "image_url": + image_url = element["image_url"]["url"] + images.append(image_url) + # processing images passed to gemini + processed_images = [] + for img in images: + if "https:/" in img: + # Case 1: Image from URL + image = _load_image_from_url(img) + processed_images.append(image) + else: + # Case 2: Image filepath (e.g. temp.jpeg) given + image = Image.open(img) + processed_images.append(image) + content = [prompt] + processed_images + return content + except Exception as e: + raise e + def gemini_text_image_pt(messages: list): """ @@ -511,7 +586,10 @@ def prompt_factory( messages=messages, prompt_format=prompt_format, chat_template=chat_template ) elif custom_llm_provider == "gemini": - return gemini_text_image_pt(messages=messages) + if model == "gemini-pro-vision": + return _gemini_vision_convert_messages(messages=messages) + else: + return gemini_text_image_pt(messages=messages) try: if "meta-llama/llama-2" in model and "chat" in model: return llama_2_chat_pt(messages=messages) diff --git a/litellm/tests/test_google_ai_studio_gemini.py b/litellm/tests/test_google_ai_studio_gemini.py new file mode 100644 index 0000000000..e9aa00d4a3 --- /dev/null +++ b/litellm/tests/test_google_ai_studio_gemini.py @@ -0,0 +1,33 @@ +import os, sys, traceback + +sys.path.insert( + 0, os.path.abspath("../..") +) # Adds the parent directory to the system path +import litellm +from dotenv import load_dotenv + +def generate_text(): + try: + messages = [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": "What's in this image?" + }, + { + "type": "image_url", + "image_url": { + "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg" + } + } + ] + } + ] + response = litellm.completion(model="gemini/gemini-pro-vision", messages=messages) + print(response) + except Exception as exception: + raise Exception("An error occurred during text generation:", exception) + +generate_text() diff --git a/litellm/utils.py b/litellm/utils.py index 8f93fb620d..4e81418767 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -4012,7 +4012,10 @@ def get_llm_provider( api_base = "https://api.voyageai.com/v1" dynamic_api_key = get_secret("VOYAGE_API_KEY") return model, custom_llm_provider, dynamic_api_key, api_base - + elif model.split("/", 1)[0] in litellm.provider_list: + custom_llm_provider = model.split("/", 1)[0] + model = model.split("/", 1)[1] + return model, custom_llm_provider, dynamic_api_key, api_base # check if api base is a known openai compatible endpoint if api_base: for endpoint in litellm.openai_compatible_endpoints: diff --git a/model_prices_and_context_window.json b/model_prices_and_context_window.json index 5745b42479..2cd1f22176 100644 --- a/model_prices_and_context_window.json +++ b/model_prices_and_context_window.json @@ -583,6 +583,22 @@ "litellm_provider": "palm", "mode": "completion" }, + "gemini/gemini-pro": { + "max_tokens": 30720, + "max_output_tokens": 2048, + "input_cost_per_token": 0.0, + "output_cost_per_token": 0.0, + "litellm_provider": "gemini", + "mode": "chat" + }, + "gemini/gemini-pro-vision": { + "max_tokens": 30720, + "max_output_tokens": 2048, + "input_cost_per_token": 0.0, + "output_cost_per_token": 0.0, + "litellm_provider": "gemini", + "mode": "chat" + }, "command-nightly": { "max_tokens": 4096, "input_cost_per_token": 0.000015,