Azure - responses api bridge - respect responses/ + Gemini - generate content bridge - handle kwargs + litellm params containing stream (#12224)

* fix(main.py): handle router custom azure model name for responses api bridge

* fix(responses/handler): ensure azure model name is stripped before sending to provider

Fixes model name error

* fix(google_genai/main.py): handle stream=true being set in kwargs

* docs: cleanup icons from sidebar

* fix(test-litellm.yml): add google-genai to test-litellm.yml

* fix(main.py): strip 'responses/' from bridge

* fix(main.py): fix linting errors

* fix(types/openai.py): allow item to be none

handle azure streaming response

* fix(base.py): allow extra fields + handle azure item = none value in response output item added event

* fix(main.py): correctly handle removing responses/

* test(test_main.py): add unit tests
This commit is contained in:
Krish Dholakia
2025-07-02 13:53:52 -07:00
committed by GitHub
parent 2c60c316ec
commit df49b24bc0
17 changed files with 311 additions and 91 deletions
+1
View File
@@ -30,6 +30,7 @@ jobs:
poetry install --with dev,proxy-dev --extras proxy
poetry run pip install "pytest-retry==1.6.3"
poetry run pip install pytest-xdist
poetry run pip install "google-genai==1.22.0"
- name: Setup litellm-enterprise as local package
run: |
cd enterprise
+1 -1
View File
@@ -2,7 +2,7 @@ import Image from '@theme/IdealImage';
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
# 🙋‍♂️ Customers / End-User Budgets
# Customers / End-User Budgets
Track spend, set budgets for your customers.
+1 -1
View File
@@ -2,7 +2,7 @@ import Image from '@theme/IdealImage';
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
# 💰 Setting Team Budgets
# Setting Team Budgets
Track spend, set budgets for your Internal Team
+1 -1
View File
@@ -1,7 +1,7 @@
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
# 💰 Budgets, Rate Limits
# Budgets, Rate Limits
Requirements:
@@ -75,9 +75,7 @@ class ResponsesToCompletionBridgeHandler:
custom_llm_provider=custom_llm_provider,
)
def completion(
self, *args, **kwargs
) -> Union[
def completion(self, *args, **kwargs) -> Union[
Coroutine[Any, Any, Union["ModelResponse", "CustomStreamWrapper"]],
"ModelResponse",
"CustomStreamWrapper",
@@ -106,6 +104,7 @@ class ResponsesToCompletionBridgeHandler:
litellm_params=litellm_params,
headers=headers,
litellm_logging_obj=logging_obj,
client=kwargs.get("client"),
)
result = responses(
@@ -121,6 +121,7 @@ class LiteLLMResponsesTransformationHandler(CompletionTransformationBridge):
litellm_params: dict,
headers: dict,
litellm_logging_obj: "LiteLLMLoggingObj",
client: Optional[Any] = None,
) -> dict:
from litellm.types.llms.openai import ResponsesAPIOptionalRequestParams
@@ -186,6 +187,7 @@ class LiteLLMResponsesTransformationHandler(CompletionTransformationBridge):
"input": input_items,
"litellm_logging_obj": litellm_logging_obj,
**litellm_params,
"client": client,
}
verbose_logger.debug(
+67 -55
View File
@@ -29,7 +29,7 @@ else:
GenerateContentConfigDict = Any
GenerateContentContentListUnionDict = Any
GenerateContentResponse = Any
####### ENVIRONMENT VARIABLES ###################
# Initialize any necessary instances or variables here
base_llm_http_handler = BaseLLMHTTPHandler()
@@ -38,6 +38,7 @@ base_llm_http_handler = BaseLLMHTTPHandler()
class GenerateContentSetupResult(BaseModel):
"""Internal Type - Result of setting up a generate content call"""
model: str
request_body: Dict[str, Any]
custom_llm_provider: str
@@ -53,7 +54,7 @@ class GenerateContentSetupResult(BaseModel):
class GenerateContentHelper:
"""Helper class for Google GenAI generate content operations"""
@staticmethod
def mock_generate_content_response(
mock_response: str = "This is a mock response from Google GenAI generate_content.",
@@ -63,20 +64,17 @@ class GenerateContentHelper:
"text": mock_response,
"candidates": [
{
"content": {
"parts": [{"text": mock_response}],
"role": "model"
},
"content": {"parts": [{"text": mock_response}], "role": "model"},
"finishReason": "STOP",
"index": 0,
"safetyRatings": []
"safetyRatings": [],
}
],
"usageMetadata": {
"promptTokenCount": 10,
"candidatesTokenCount": 20,
"totalTokenCount": 30
}
"totalTokenCount": 30,
},
}
@staticmethod
@@ -86,11 +84,11 @@ class GenerateContentHelper:
config: Optional[GenerateContentConfigDict] = None,
custom_llm_provider: Optional[str] = None,
stream: bool = False,
**kwargs
**kwargs,
) -> GenerateContentSetupResult:
"""
Common setup logic for generate_content calls
Args:
model: The model name
contents: The content to generate from
@@ -99,18 +97,24 @@ class GenerateContentHelper:
stream: Whether this is a streaming call
local_vars: Local variables from the calling function
**kwargs: Additional keyword arguments
Returns:
GenerateContentSetupResult containing all setup information
"""
litellm_logging_obj: Optional[LiteLLMLoggingObj] = kwargs.get("litellm_logging_obj")
litellm_logging_obj: Optional[LiteLLMLoggingObj] = kwargs.get(
"litellm_logging_obj"
)
litellm_call_id: Optional[str] = kwargs.get("litellm_call_id", None)
# get llm provider logic
litellm_params = GenericLiteLLMParams(**kwargs)
## MOCK RESPONSE LOGIC (only for non-streaming)
if not stream and litellm_params.mock_response and isinstance(litellm_params.mock_response, str):
if (
not stream
and litellm_params.mock_response
and isinstance(litellm_params.mock_response, str)
):
raise ValueError("Mock response should be handled by caller")
(
@@ -126,11 +130,11 @@ class GenerateContentHelper:
)
# get provider config
generate_content_provider_config: Optional[BaseGoogleGenAIGenerateContentConfig] = (
ProviderConfigManager.get_provider_google_genai_generate_content_config(
model=model,
provider=litellm.LlmProviders(custom_llm_provider),
)
generate_content_provider_config: Optional[
BaseGoogleGenAIGenerateContentConfig
] = ProviderConfigManager.get_provider_google_genai_generate_content_config(
model=model,
provider=litellm.LlmProviders(custom_llm_provider),
)
if generate_content_provider_config is None:
@@ -146,28 +150,31 @@ class GenerateContentHelper:
generate_content_config_dict=dict(config or {}),
litellm_params=litellm_params,
litellm_logging_obj=litellm_logging_obj,
litellm_call_id=litellm_call_id
litellm_call_id=litellm_call_id,
)
#########################################################################################
# Construct request body
#########################################################################################
# Create Google Optional Params Config
generate_content_config_dict = generate_content_provider_config.map_generate_content_optional_params(
generate_content_config_dict=config or {},
model=model,
generate_content_config_dict = (
generate_content_provider_config.map_generate_content_optional_params(
generate_content_config_dict=config or {},
model=model,
)
)
request_body = generate_content_provider_config.transform_generate_content_request(
model=model,
contents=contents,
generate_content_config_dict=generate_content_config_dict,
request_body = (
generate_content_provider_config.transform_generate_content_request(
model=model,
contents=contents,
generate_content_config_dict=generate_content_config_dict,
)
)
# Pre Call logging
if litellm_logging_obj is None:
raise ValueError("litellm_logging_obj is required, but got None")
litellm_logging_obj.update_environment_variables(
model=model,
optional_params=dict(generate_content_config_dict),
@@ -185,7 +192,7 @@ class GenerateContentHelper:
generate_content_config_dict=generate_content_config_dict,
litellm_params=litellm_params,
litellm_logging_obj=litellm_logging_obj,
litellm_call_id=litellm_call_id
litellm_call_id=litellm_call_id,
)
@@ -202,7 +209,7 @@ async def agenerate_content(
timeout: Optional[Union[float, httpx.Timeout]] = None,
# LiteLLM specific params,
custom_llm_provider: Optional[str] = None,
**kwargs
**kwargs,
) -> Any:
"""
Async: Generate content using Google GenAI
@@ -273,10 +280,12 @@ def generate_content(
local_vars = locals()
try:
_is_async = kwargs.pop("agenerate_content", False) is True
# Check for mock response first
litellm_params = GenericLiteLLMParams(**kwargs)
if litellm_params.mock_response and isinstance(litellm_params.mock_response, str):
if litellm_params.mock_response and isinstance(
litellm_params.mock_response, str
):
return GenerateContentHelper.mock_generate_content_response(
mock_response=litellm_params.mock_response
)
@@ -288,7 +297,7 @@ def generate_content(
config=config,
custom_llm_provider=custom_llm_provider,
stream=False,
**kwargs
**kwargs,
)
# Check if we should use the adapter (when provider config is None)
@@ -301,7 +310,7 @@ def generate_content(
stream=False,
_is_async=_is_async,
litellm_params=setup_result.litellm_params,
**kwargs
**kwargs,
)
# Call the standard handler
@@ -346,7 +355,7 @@ async def agenerate_content_stream(
timeout: Optional[Union[float, httpx.Timeout]] = None,
# LiteLLM specific params,
custom_llm_provider: Optional[str] = None,
**kwargs
**kwargs,
) -> Any:
"""
Async: Generate content using Google GenAI with streaming response
@@ -354,7 +363,7 @@ async def agenerate_content_stream(
local_vars = locals()
try:
kwargs["agenerate_content_stream"] = True
# get custom llm provider so we can use this for mapping exceptions
if custom_llm_provider is None:
_, custom_llm_provider, _, _ = litellm.get_llm_provider(
@@ -363,24 +372,28 @@ async def agenerate_content_stream(
# Setup the call
setup_result = GenerateContentHelper.setup_generate_content_call(
model=model,
contents=contents,
config=config,
custom_llm_provider=custom_llm_provider,
stream=True,
**kwargs
**{
"model": model,
"contents": contents,
"config": config,
"custom_llm_provider": custom_llm_provider,
"stream": True,
**kwargs,
}
)
# Check if we should use the adapter (when provider config is None)
if setup_result.generate_content_provider_config is None:
# Use the adapter to convert to completion format
return await GenerateContentToCompletionHandler.async_generate_content_handler(
model=setup_result.model,
contents=contents, # type: ignore
config=setup_result.generate_content_config_dict,
litellm_params=setup_result.litellm_params,
stream=True,
**kwargs
return (
await GenerateContentToCompletionHandler.async_generate_content_handler(
model=setup_result.model,
contents=contents, # type: ignore
config=setup_result.generate_content_config_dict,
litellm_params=setup_result.litellm_params,
stream=True,
**kwargs,
)
)
# Call the handler with async enabled and streaming
@@ -401,7 +414,7 @@ async def agenerate_content_stream(
stream=True,
litellm_metadata=kwargs.get("litellm_metadata", {}),
)
except Exception as e:
raise litellm.exception_type(
model=model,
@@ -442,7 +455,7 @@ def generate_content_stream(
config=config,
custom_llm_provider=custom_llm_provider,
stream=True,
**kwargs
**kwargs,
)
# Check if we should use the adapter (when provider config is None)
@@ -455,7 +468,7 @@ def generate_content_stream(
stream=True,
_is_async=_is_async,
litellm_params=setup_result.litellm_params,
**kwargs
**kwargs,
)
# Call the handler with streaming enabled (sync version)
@@ -484,4 +497,3 @@ def generate_content_stream(
completion_kwargs=local_vars,
extra_kwargs=kwargs,
)
+28 -6
View File
@@ -20,8 +20,33 @@ class AzureOpenAIResponsesAPIConfig(OpenAIResponsesAPIConfig):
self, headers: dict, model: str, litellm_params: Optional[GenericLiteLLMParams]
) -> dict:
return BaseAzureLLM._base_validate_azure_environment(
headers=headers,
litellm_params=litellm_params
headers=headers, litellm_params=litellm_params
)
def get_stripped_model_name(self, model: str) -> str:
    """
    Return `model` with the routing segments removed.

    Every occurrence of "responses/" is dropped first, then every
    occurrence of "o_series/". A name containing neither is returned
    unchanged.
    """
    # str.replace leaves the string untouched when the segment is absent,
    # so no membership check is needed before each call.
    for routing_segment in ("responses/", "o_series/"):
        model = model.replace(routing_segment, "")
    return model
def transform_responses_api_request(
    self,
    model: str,
    input: Union[str, ResponseInputParam],
    response_api_optional_request_params: Dict,
    litellm_params: GenericLiteLLMParams,
    headers: dict,
) -> Dict:
    """
    Build the request body for a Responses API call.

    `input` and the optional request params are forwarded as-is since they
    are in OpenAI spec already; only the model name is rewritten, by passing
    it through `get_stripped_model_name` before it is placed in the body.

    Args:
        model: Model name as received; stripped before being sent.
        input: The Responses API input, forwarded unchanged.
        response_api_optional_request_params: Extra request params, expanded
            into the request body unchanged.
        litellm_params: Not read by this method.
        headers: Not read by this method.

    Returns:
        A plain dict built from `ResponsesAPIRequestParams`.
    """
    stripped_model_name = self.get_stripped_model_name(model)
    return dict(
        ResponsesAPIRequestParams(
            model=stripped_model_name,
            input=input,
            **response_api_optional_request_params,
        )
    )
def get_complete_url(
@@ -46,11 +71,8 @@ class AzureOpenAIResponsesAPIConfig(OpenAIResponsesAPIConfig):
"https://litellm8397336933.openai.azure.com/openai/responses?api-version=2024-05-01-preview"
"""
return BaseAzureLLM._get_base_azure_url(
api_base=api_base,
litellm_params=litellm_params,
route="/openai/responses"
api_base=api_base, litellm_params=litellm_params, route="/openai/responses"
)
#########################################################
########## DELETE RESPONSE API TRANSFORMATION ##############
+14 -7
View File
@@ -1008,12 +1008,12 @@ class BaseLLMHTTPHandler:
"""
Shared logic for preparing audio transcription requests.
Returns: (headers, complete_url, data, files)
"""
"""
# Handle the response based on type
from litellm.llms.base_llm.audio_transcription.transformation import (
AudioTranscriptionRequestData,
)
headers = provider_config.validate_environment(
api_key=api_key,
headers=headers or {},
@@ -1038,11 +1038,13 @@ class BaseLLMHTTPHandler:
optional_params=optional_params,
litellm_params=litellm_params,
)
# All providers now return AudioTranscriptionRequestData
if not isinstance(transformed_result, AudioTranscriptionRequestData):
raise ValueError(f"Provider {provider_config.__class__.__name__} must return AudioTranscriptionRequestData")
raise ValueError(
f"Provider {provider_config.__class__.__name__} must return AudioTranscriptionRequestData"
)
data = transformed_result.data
files = transformed_result.files
@@ -1143,7 +1145,9 @@ class BaseLLMHTTPHandler:
headers=headers,
data=data,
files=files,
json=data if files is None and isinstance(data, dict) else None, # Use json param only when no files and data is dict
json=(
data if files is None and isinstance(data, dict) else None
), # Use json param only when no files and data is dict
timeout=timeout,
)
except Exception as e:
@@ -1214,7 +1218,9 @@ class BaseLLMHTTPHandler:
headers=headers,
data=data,
files=files,
json=data if files is None and isinstance(data, dict) else None, # Use json param only when no files and data is dict
json=(
data if files is None and isinstance(data, dict) else None
), # Use json param only when no files and data is dict
timeout=timeout,
)
except Exception as e:
@@ -1432,6 +1438,7 @@ class BaseLLMHTTPHandler:
Handles responses API requests.
When _is_async=True, returns a coroutine instead of making the call directly.
"""
if _is_async:
# Return the async coroutine if called with _is_async=True
return self.async_response_api_handler(
+36 -14
View File
@@ -31,6 +31,7 @@ from typing import (
Literal,
Mapping,
Optional,
Tuple,
Type,
Union,
cast,
@@ -822,6 +823,34 @@ def mock_completion(
raise Exception("Mock completion response failed - {}".format(e))
def responses_api_bridge_check(
    model: str,
    custom_llm_provider: str,
) -> Tuple[dict, str]:
    """
    Decide whether a completion call should go through the responses API bridge.

    Looks up the model info for `model`. If that info carries no `mode` (or
    the lookup fails) and the model name starts with `responses/`, the prefix
    is removed and `mode` is set to "responses".

    Args:
        model: Model name, optionally prefixed with `responses/`
            (e.g. azure models - `azure/responses/<deployment-name>`).
        custom_llm_provider: Provider passed to the model info lookup.

    Returns:
        `(model_info, model)` - the model info dict (empty if the lookup
        failed) and the model name with any leading `responses/` removed.
    """
    responses_prefix = "responses/"
    model_info: Dict[str, Any] = {}
    try:
        # `or {}` keeps a None/empty lookup result from breaking the
        # `.get` / item assignment below.
        model_info = (
            cast(
                dict,
                _get_model_info_helper(
                    model=model, custom_llm_provider=custom_llm_provider
                ),
            )
            or {}
        )
    except Exception as e:
        # Lookup failure is expected for custom deployment names; fall back
        # to the empty model_info and rely on the prefix check alone.
        verbose_logger.debug("Error getting model info: {}".format(e))

    if model_info.get("mode") is None and model.startswith(responses_prefix):
        # Strip only the leading prefix: a later "responses/" segment is part
        # of the deployment name and must be preserved.
        model = model[len(responses_prefix) :]
        model_info["mode"] = "responses"

    return model_info, model
@tracer.wrap()
@client
def completion( # type: ignore # noqa: PLR0915
@@ -1290,19 +1319,9 @@ def completion( # type: ignore # noqa: PLR0915
)
## RESPONSES API BRIDGE LOGIC ## - check if model has 'mode: responses' in litellm.model_cost map
try:
model_info = _get_model_info_helper(
model=model, custom_llm_provider=custom_llm_provider
)
except Exception as e:
verbose_logger.debug("Error getting model info: {}".format(e))
model_info = {}
if model.startswith(
"responses/"
): # handle azure models - `azure/responses/<deployment-name>`
model = model.split("/")[1]
mode = "responses"
model_info["mode"] = mode
model_info, model = responses_api_bridge_check(
model=model, custom_llm_provider=custom_llm_provider
)
if model_info.get("mode") == "responses":
from litellm.completion_extras import responses_api_bridge
@@ -4939,7 +4958,10 @@ def transcription(
provider_config=provider_config,
litellm_params=litellm_params_dict,
)
elif custom_llm_provider in [LlmProviders.DEEPGRAM.value, LlmProviders.ELEVENLABS.value]:
elif custom_llm_provider in [
LlmProviders.DEEPGRAM.value,
LlmProviders.ELEVENLABS.value,
]:
response = base_llm_http_handler.audio_transcriptions(
model=model,
audio_file=file,
+2 -1
View File
@@ -23,8 +23,9 @@ class LiteLLMPydanticObjectBase(BaseModel):
model_config = ConfigDict(protected_namespaces=())
class BaseLiteLLMOpenAIResponseObject(BaseModel):
model_config = ConfigDict(extra="allow", protected_namespaces=())
def __getitem__(self, key):
return self.__dict__[key]
+1 -1
View File
@@ -1134,7 +1134,7 @@ class ReasoningSummaryTextDeltaEvent(BaseLiteLLMOpenAIResponseObject):
class OutputItemAddedEvent(BaseLiteLLMOpenAIResponseObject):
type: Literal[ResponsesAPIStreamEvents.OUTPUT_ITEM_ADDED]
output_index: int
item: dict
item: Optional[dict]
class OutputItemDoneEvent(BaseLiteLLMOpenAIResponseObject):
+6 -1
View File
@@ -162,7 +162,12 @@ class ModelInfoBase(ProviderSpecificModelInfo, total=False):
litellm_provider: Required[str]
mode: Required[
Literal[
"completion", "embedding", "image_generation", "chat", "audio_transcription"
"completion",
"embedding",
"image_generation",
"chat",
"audio_transcription",
"responses",
]
]
tpm: Optional[int]
@@ -0,0 +1,45 @@
#!/usr/bin/env python3
"""
Test to verify the Google GenAI generate_content adapter functionality
"""
import json
import os
import sys
import pytest
sys.path.insert(
0, os.path.abspath("../../..")
) # Adds the parent directory to the system path
import json
import os
import sys
import pytest
import litellm
@pytest.mark.asyncio
async def test_agenerate_content_stream():
    """
    Test that agenerate_content_stream reaches the HTTP handler exactly once
    with stream=True, even when the caller also passes `stream=True` as a
    keyword argument.
    """
    from unittest.mock import AsyncMock, patch

    from litellm.google_genai.main import (
        agenerate_content_stream,
        base_llm_http_handler,
    )

    with patch.object(
        base_llm_http_handler, "generate_content_handler", new=AsyncMock()
    ) as mock_post:
        await agenerate_content_stream(
            model="gemini/gemini-2.0-flash-001",
            contents="Hello, world!",
            stream=True,
        )

        mock_post.assert_called_once()
        # Must be an `assert`: a bare comparison is evaluated and discarded,
        # so the test would pass regardless of the value.
        assert mock_post.call_args.kwargs["stream"] is True
+31
View File
@@ -487,3 +487,34 @@ def test_bedrock_llama():
request["raw_request_body"]["prompt"]
== "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nhi<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
)
def test_responses_api_bridge_check_strips_responses_prefix():
    """A leading 'responses/' is removed from the model name and mode becomes 'responses'."""
    from litellm.main import responses_api_bridge_check

    # The stubbed model info has no "mode" key, so the prefix decides the mode.
    with patch(
        "litellm.main._get_model_info_helper",
        return_value={"max_tokens": 4096},
    ):
        bridge_info, stripped_model = responses_api_bridge_check(
            model="responses/gpt-4-responses",
            custom_llm_provider="openai",
        )

        assert stripped_model == "gpt-4-responses"
        assert bridge_info["mode"] == "responses"
def test_responses_api_bridge_check_handles_exception():
    """A failing model info lookup still yields a stripped name and mode 'responses'."""
    from litellm.main import responses_api_bridge_check

    # Make the lookup raise so the fallback path is the one exercised.
    with patch(
        "litellm.main._get_model_info_helper",
        side_effect=Exception("Model not found"),
    ):
        bridge_info, stripped_model = responses_api_bridge_check(
            model="responses/custom-model", custom_llm_provider="custom"
        )

        assert stripped_model == "custom-model"
        assert bridge_info["mode"] == "responses"
+57
View File
@@ -602,3 +602,60 @@ def test_router_should_include_deployment():
assert (
result is True
), "Should return True when matching model with exact model_name"
def test_router_responses_api_bridge():
    """
    Test that router.completion routes an `azure/responses/...` deployment
    through `litellm.responses`, and that the request sent to the provider
    carries the stripped model name.
    """
    # AsyncMock is imported here alongside MagicMock/patch so the test does
    # not depend on a module-level import being present.
    from unittest.mock import AsyncMock, MagicMock, patch

    from litellm.llms.custom_httpx.http_handler import HTTPHandler

    router = litellm.Router(
        model_list=[
            {
                "model_name": "[IP-approved] o3-pro",
                "litellm_params": {
                    "model": "azure/responses/o_series/webinterface-o3-pro",
                    "api_base": "https://webhook.site/fba79dae-220a-4bb7-9a3a-8caa49604e55",
                    "api_key": "sk-1234567890",
                    "api_version": "preview",
                    "stream": True,
                },
                "model_info": {
                    "input_cost_per_token": 0.00002,
                    "output_cost_per_token": 0.00008,
                },
            }
        ],
    )

    ## CONFIRM BRIDGE IS CALLED
    with patch.object(litellm, "responses", return_value=AsyncMock()) as mock_responses:
        router.completion(
            model="[IP-approved] o3-pro",
            messages=[{"role": "user", "content": "Hello, world!"}],
        )
        assert mock_responses.call_count == 1

    ## CONFIRM MODEL NAME IS STRIPPED
    client = HTTPHandler()

    with patch.object(client, "post", return_value=MagicMock()) as mock_post:
        try:
            router.completion(
                model="[IP-approved] o3-pro",
                messages=[{"role": "user", "content": "Hello, world!"}],
                client=client,
                num_retries=0,
            )
        except Exception as e:
            # Best-effort: only the outgoing request is under test.
            # NOTE(review): presumably the MagicMock response cannot be
            # parsed downstream - confirm before tightening this.
            print(f"Error: {e}")

        assert mock_post.call_count == 1
        assert (
            mock_post.call_args.kwargs["url"]
            == "https://webhook.site/fba79dae-220a-4bb7-9a3a-8caa49604e55/openai/v1/responses?api-version=preview"
        )
        assert mock_post.call_args.kwargs["json"]["model"] == "webinterface-o3-pro"
@@ -19,3 +19,19 @@ def test_generic_event():
event = GenericEvent(**event)
assert event.type == "test"
assert event.test == "test"
def test_output_item_added_event():
    """OutputItemAddedEvent can be built with `item=None` and exposes every field passed in."""
    from litellm.types.llms.openai import OutputItemAddedEvent

    payload = {
        "type": "response.output_item.added",
        "sequence_number": 4,
        "output_index": 1,
        "item": None,
    }

    parsed_event = OutputItemAddedEvent(**payload)

    assert parsed_event.type == "response.output_item.added"
    assert parsed_event.sequence_number == 4
    assert parsed_event.output_index == 1
    assert parsed_event.item is None