mirror of
https://github.com/tiennm99/litellm.git
synced 2026-06-18 00:48:01 +00:00
Azure - responses api bridge - respect responses/ + Gemini - generate content bridge - handle kwargs + litellm params containing stream (#12224)
* fix(main.py): handle router custom azure model name for responses api bridge * fix(responses/handler): ensure azure model name is stripped before sending to provider Fixes model name error * fix(google_genai/main.py): handle stream=true being set in kwargs * docs: cleanup icons from sidebar * fix(test-litellm.yml): add google-genai to test litellmyml * fix(main.py): strip 'responses/' from bridge * fix(main.py): fix linting errors * fix(types/openai.py): allow item to be none handle azure streaming response * fix(base.py): allow extra fields + handle azure item = none value in response output item added event * fix(main.py): correctly handle removing responses/ * test(test_main.py): add unit tests
This commit is contained in:
@@ -30,6 +30,7 @@ jobs:
|
||||
poetry install --with dev,proxy-dev --extras proxy
|
||||
poetry run pip install "pytest-retry==1.6.3"
|
||||
poetry run pip install pytest-xdist
|
||||
poetry run pip install "google-genai==1.22.0"
|
||||
- name: Setup litellm-enterprise as local package
|
||||
run: |
|
||||
cd enterprise
|
||||
|
||||
@@ -2,7 +2,7 @@ import Image from '@theme/IdealImage';
|
||||
import Tabs from '@theme/Tabs';
|
||||
import TabItem from '@theme/TabItem';
|
||||
|
||||
# 🙋♂️ Customers / End-User Budgets
|
||||
# Customers / End-User Budgets
|
||||
|
||||
Track spend, set budgets for your customers.
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ import Image from '@theme/IdealImage';
|
||||
import Tabs from '@theme/Tabs';
|
||||
import TabItem from '@theme/TabItem';
|
||||
|
||||
# 💰 Setting Team Budgets
|
||||
# Setting Team Budgets
|
||||
|
||||
Track spend, set budgets for your Internal Team
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import Tabs from '@theme/Tabs';
|
||||
import TabItem from '@theme/TabItem';
|
||||
|
||||
# 💰 Budgets, Rate Limits
|
||||
# Budgets, Rate Limits
|
||||
|
||||
Requirements:
|
||||
|
||||
|
||||
@@ -75,9 +75,7 @@ class ResponsesToCompletionBridgeHandler:
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
)
|
||||
|
||||
def completion(
|
||||
self, *args, **kwargs
|
||||
) -> Union[
|
||||
def completion(self, *args, **kwargs) -> Union[
|
||||
Coroutine[Any, Any, Union["ModelResponse", "CustomStreamWrapper"]],
|
||||
"ModelResponse",
|
||||
"CustomStreamWrapper",
|
||||
@@ -106,6 +104,7 @@ class ResponsesToCompletionBridgeHandler:
|
||||
litellm_params=litellm_params,
|
||||
headers=headers,
|
||||
litellm_logging_obj=logging_obj,
|
||||
client=kwargs.get("client"),
|
||||
)
|
||||
|
||||
result = responses(
|
||||
|
||||
@@ -121,6 +121,7 @@ class LiteLLMResponsesTransformationHandler(CompletionTransformationBridge):
|
||||
litellm_params: dict,
|
||||
headers: dict,
|
||||
litellm_logging_obj: "LiteLLMLoggingObj",
|
||||
client: Optional[Any] = None,
|
||||
) -> dict:
|
||||
from litellm.types.llms.openai import ResponsesAPIOptionalRequestParams
|
||||
|
||||
@@ -186,6 +187,7 @@ class LiteLLMResponsesTransformationHandler(CompletionTransformationBridge):
|
||||
"input": input_items,
|
||||
"litellm_logging_obj": litellm_logging_obj,
|
||||
**litellm_params,
|
||||
"client": client,
|
||||
}
|
||||
|
||||
verbose_logger.debug(
|
||||
|
||||
@@ -29,7 +29,7 @@ else:
|
||||
GenerateContentConfigDict = Any
|
||||
GenerateContentContentListUnionDict = Any
|
||||
GenerateContentResponse = Any
|
||||
|
||||
|
||||
####### ENVIRONMENT VARIABLES ###################
|
||||
# Initialize any necessary instances or variables here
|
||||
base_llm_http_handler = BaseLLMHTTPHandler()
|
||||
@@ -38,6 +38,7 @@ base_llm_http_handler = BaseLLMHTTPHandler()
|
||||
|
||||
class GenerateContentSetupResult(BaseModel):
|
||||
"""Internal Type - Result of setting up a generate content call"""
|
||||
|
||||
model: str
|
||||
request_body: Dict[str, Any]
|
||||
custom_llm_provider: str
|
||||
@@ -53,7 +54,7 @@ class GenerateContentSetupResult(BaseModel):
|
||||
|
||||
class GenerateContentHelper:
|
||||
"""Helper class for Google GenAI generate content operations"""
|
||||
|
||||
|
||||
@staticmethod
|
||||
def mock_generate_content_response(
|
||||
mock_response: str = "This is a mock response from Google GenAI generate_content.",
|
||||
@@ -63,20 +64,17 @@ class GenerateContentHelper:
|
||||
"text": mock_response,
|
||||
"candidates": [
|
||||
{
|
||||
"content": {
|
||||
"parts": [{"text": mock_response}],
|
||||
"role": "model"
|
||||
},
|
||||
"content": {"parts": [{"text": mock_response}], "role": "model"},
|
||||
"finishReason": "STOP",
|
||||
"index": 0,
|
||||
"safetyRatings": []
|
||||
"safetyRatings": [],
|
||||
}
|
||||
],
|
||||
"usageMetadata": {
|
||||
"promptTokenCount": 10,
|
||||
"candidatesTokenCount": 20,
|
||||
"totalTokenCount": 30
|
||||
}
|
||||
"totalTokenCount": 30,
|
||||
},
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
@@ -86,11 +84,11 @@ class GenerateContentHelper:
|
||||
config: Optional[GenerateContentConfigDict] = None,
|
||||
custom_llm_provider: Optional[str] = None,
|
||||
stream: bool = False,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
) -> GenerateContentSetupResult:
|
||||
"""
|
||||
Common setup logic for generate_content calls
|
||||
|
||||
|
||||
Args:
|
||||
model: The model name
|
||||
contents: The content to generate from
|
||||
@@ -99,18 +97,24 @@ class GenerateContentHelper:
|
||||
stream: Whether this is a streaming call
|
||||
local_vars: Local variables from the calling function
|
||||
**kwargs: Additional keyword arguments
|
||||
|
||||
|
||||
Returns:
|
||||
GenerateContentSetupResult containing all setup information
|
||||
"""
|
||||
litellm_logging_obj: Optional[LiteLLMLoggingObj] = kwargs.get("litellm_logging_obj")
|
||||
litellm_logging_obj: Optional[LiteLLMLoggingObj] = kwargs.get(
|
||||
"litellm_logging_obj"
|
||||
)
|
||||
litellm_call_id: Optional[str] = kwargs.get("litellm_call_id", None)
|
||||
|
||||
|
||||
# get llm provider logic
|
||||
litellm_params = GenericLiteLLMParams(**kwargs)
|
||||
|
||||
## MOCK RESPONSE LOGIC (only for non-streaming)
|
||||
if not stream and litellm_params.mock_response and isinstance(litellm_params.mock_response, str):
|
||||
if (
|
||||
not stream
|
||||
and litellm_params.mock_response
|
||||
and isinstance(litellm_params.mock_response, str)
|
||||
):
|
||||
raise ValueError("Mock response should be handled by caller")
|
||||
|
||||
(
|
||||
@@ -126,11 +130,11 @@ class GenerateContentHelper:
|
||||
)
|
||||
|
||||
# get provider config
|
||||
generate_content_provider_config: Optional[BaseGoogleGenAIGenerateContentConfig] = (
|
||||
ProviderConfigManager.get_provider_google_genai_generate_content_config(
|
||||
model=model,
|
||||
provider=litellm.LlmProviders(custom_llm_provider),
|
||||
)
|
||||
generate_content_provider_config: Optional[
|
||||
BaseGoogleGenAIGenerateContentConfig
|
||||
] = ProviderConfigManager.get_provider_google_genai_generate_content_config(
|
||||
model=model,
|
||||
provider=litellm.LlmProviders(custom_llm_provider),
|
||||
)
|
||||
|
||||
if generate_content_provider_config is None:
|
||||
@@ -146,28 +150,31 @@ class GenerateContentHelper:
|
||||
generate_content_config_dict=dict(config or {}),
|
||||
litellm_params=litellm_params,
|
||||
litellm_logging_obj=litellm_logging_obj,
|
||||
litellm_call_id=litellm_call_id
|
||||
litellm_call_id=litellm_call_id,
|
||||
)
|
||||
|
||||
|
||||
#########################################################################################
|
||||
# Construct request body
|
||||
#########################################################################################
|
||||
# Create Google Optional Params Config
|
||||
generate_content_config_dict = generate_content_provider_config.map_generate_content_optional_params(
|
||||
generate_content_config_dict=config or {},
|
||||
model=model,
|
||||
generate_content_config_dict = (
|
||||
generate_content_provider_config.map_generate_content_optional_params(
|
||||
generate_content_config_dict=config or {},
|
||||
model=model,
|
||||
)
|
||||
)
|
||||
request_body = generate_content_provider_config.transform_generate_content_request(
|
||||
model=model,
|
||||
contents=contents,
|
||||
generate_content_config_dict=generate_content_config_dict,
|
||||
request_body = (
|
||||
generate_content_provider_config.transform_generate_content_request(
|
||||
model=model,
|
||||
contents=contents,
|
||||
generate_content_config_dict=generate_content_config_dict,
|
||||
)
|
||||
)
|
||||
|
||||
# Pre Call logging
|
||||
if litellm_logging_obj is None:
|
||||
raise ValueError("litellm_logging_obj is required, but got None")
|
||||
|
||||
|
||||
litellm_logging_obj.update_environment_variables(
|
||||
model=model,
|
||||
optional_params=dict(generate_content_config_dict),
|
||||
@@ -185,7 +192,7 @@ class GenerateContentHelper:
|
||||
generate_content_config_dict=generate_content_config_dict,
|
||||
litellm_params=litellm_params,
|
||||
litellm_logging_obj=litellm_logging_obj,
|
||||
litellm_call_id=litellm_call_id
|
||||
litellm_call_id=litellm_call_id,
|
||||
)
|
||||
|
||||
|
||||
@@ -202,7 +209,7 @@ async def agenerate_content(
|
||||
timeout: Optional[Union[float, httpx.Timeout]] = None,
|
||||
# LiteLLM specific params,
|
||||
custom_llm_provider: Optional[str] = None,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
) -> Any:
|
||||
"""
|
||||
Async: Generate content using Google GenAI
|
||||
@@ -273,10 +280,12 @@ def generate_content(
|
||||
local_vars = locals()
|
||||
try:
|
||||
_is_async = kwargs.pop("agenerate_content", False) is True
|
||||
|
||||
|
||||
# Check for mock response first
|
||||
litellm_params = GenericLiteLLMParams(**kwargs)
|
||||
if litellm_params.mock_response and isinstance(litellm_params.mock_response, str):
|
||||
if litellm_params.mock_response and isinstance(
|
||||
litellm_params.mock_response, str
|
||||
):
|
||||
return GenerateContentHelper.mock_generate_content_response(
|
||||
mock_response=litellm_params.mock_response
|
||||
)
|
||||
@@ -288,7 +297,7 @@ def generate_content(
|
||||
config=config,
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
stream=False,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Check if we should use the adapter (when provider config is None)
|
||||
@@ -301,7 +310,7 @@ def generate_content(
|
||||
stream=False,
|
||||
_is_async=_is_async,
|
||||
litellm_params=setup_result.litellm_params,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Call the standard handler
|
||||
@@ -346,7 +355,7 @@ async def agenerate_content_stream(
|
||||
timeout: Optional[Union[float, httpx.Timeout]] = None,
|
||||
# LiteLLM specific params,
|
||||
custom_llm_provider: Optional[str] = None,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
) -> Any:
|
||||
"""
|
||||
Async: Generate content using Google GenAI with streaming response
|
||||
@@ -354,7 +363,7 @@ async def agenerate_content_stream(
|
||||
local_vars = locals()
|
||||
try:
|
||||
kwargs["agenerate_content_stream"] = True
|
||||
|
||||
|
||||
# get custom llm provider so we can use this for mapping exceptions
|
||||
if custom_llm_provider is None:
|
||||
_, custom_llm_provider, _, _ = litellm.get_llm_provider(
|
||||
@@ -363,24 +372,28 @@ async def agenerate_content_stream(
|
||||
|
||||
# Setup the call
|
||||
setup_result = GenerateContentHelper.setup_generate_content_call(
|
||||
model=model,
|
||||
contents=contents,
|
||||
config=config,
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
stream=True,
|
||||
**kwargs
|
||||
**{
|
||||
"model": model,
|
||||
"contents": contents,
|
||||
"config": config,
|
||||
"custom_llm_provider": custom_llm_provider,
|
||||
"stream": True,
|
||||
**kwargs,
|
||||
}
|
||||
)
|
||||
|
||||
# Check if we should use the adapter (when provider config is None)
|
||||
if setup_result.generate_content_provider_config is None:
|
||||
# Use the adapter to convert to completion format
|
||||
return await GenerateContentToCompletionHandler.async_generate_content_handler(
|
||||
model=setup_result.model,
|
||||
contents=contents, # type: ignore
|
||||
config=setup_result.generate_content_config_dict,
|
||||
litellm_params=setup_result.litellm_params,
|
||||
stream=True,
|
||||
**kwargs
|
||||
return (
|
||||
await GenerateContentToCompletionHandler.async_generate_content_handler(
|
||||
model=setup_result.model,
|
||||
contents=contents, # type: ignore
|
||||
config=setup_result.generate_content_config_dict,
|
||||
litellm_params=setup_result.litellm_params,
|
||||
stream=True,
|
||||
**kwargs,
|
||||
)
|
||||
)
|
||||
|
||||
# Call the handler with async enabled and streaming
|
||||
@@ -401,7 +414,7 @@ async def agenerate_content_stream(
|
||||
stream=True,
|
||||
litellm_metadata=kwargs.get("litellm_metadata", {}),
|
||||
)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
raise litellm.exception_type(
|
||||
model=model,
|
||||
@@ -442,7 +455,7 @@ def generate_content_stream(
|
||||
config=config,
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
stream=True,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Check if we should use the adapter (when provider config is None)
|
||||
@@ -455,7 +468,7 @@ def generate_content_stream(
|
||||
stream=True,
|
||||
_is_async=_is_async,
|
||||
litellm_params=setup_result.litellm_params,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Call the handler with streaming enabled (sync version)
|
||||
@@ -484,4 +497,3 @@ def generate_content_stream(
|
||||
completion_kwargs=local_vars,
|
||||
extra_kwargs=kwargs,
|
||||
)
|
||||
|
||||
|
||||
@@ -20,8 +20,33 @@ class AzureOpenAIResponsesAPIConfig(OpenAIResponsesAPIConfig):
|
||||
self, headers: dict, model: str, litellm_params: Optional[GenericLiteLLMParams]
|
||||
) -> dict:
|
||||
return BaseAzureLLM._base_validate_azure_environment(
|
||||
headers=headers,
|
||||
litellm_params=litellm_params
|
||||
headers=headers, litellm_params=litellm_params
|
||||
)
|
||||
|
||||
def get_stripped_model_name(self, model: str) -> str:
|
||||
# if "responses/" is in the model name, remove it
|
||||
if "responses/" in model:
|
||||
model = model.replace("responses/", "")
|
||||
if "o_series" in model:
|
||||
model = model.replace("o_series/", "")
|
||||
return model
|
||||
|
||||
def transform_responses_api_request(
|
||||
self,
|
||||
model: str,
|
||||
input: Union[str, ResponseInputParam],
|
||||
response_api_optional_request_params: Dict,
|
||||
litellm_params: GenericLiteLLMParams,
|
||||
headers: dict,
|
||||
) -> Dict:
|
||||
"""No transform applied since inputs are in OpenAI spec already"""
|
||||
stripped_model_name = self.get_stripped_model_name(model)
|
||||
return dict(
|
||||
ResponsesAPIRequestParams(
|
||||
model=stripped_model_name,
|
||||
input=input,
|
||||
**response_api_optional_request_params,
|
||||
)
|
||||
)
|
||||
|
||||
def get_complete_url(
|
||||
@@ -46,11 +71,8 @@ class AzureOpenAIResponsesAPIConfig(OpenAIResponsesAPIConfig):
|
||||
"https://litellm8397336933.openai.azure.com/openai/responses?api-version=2024-05-01-preview"
|
||||
"""
|
||||
return BaseAzureLLM._get_base_azure_url(
|
||||
api_base=api_base,
|
||||
litellm_params=litellm_params,
|
||||
route="/openai/responses"
|
||||
api_base=api_base, litellm_params=litellm_params, route="/openai/responses"
|
||||
)
|
||||
|
||||
|
||||
#########################################################
|
||||
########## DELETE RESPONSE API TRANSFORMATION ##############
|
||||
|
||||
@@ -1008,12 +1008,12 @@ class BaseLLMHTTPHandler:
|
||||
"""
|
||||
Shared logic for preparing audio transcription requests.
|
||||
Returns: (headers, complete_url, data, files)
|
||||
"""
|
||||
"""
|
||||
# Handle the response based on type
|
||||
from litellm.llms.base_llm.audio_transcription.transformation import (
|
||||
AudioTranscriptionRequestData,
|
||||
)
|
||||
|
||||
|
||||
headers = provider_config.validate_environment(
|
||||
api_key=api_key,
|
||||
headers=headers or {},
|
||||
@@ -1038,11 +1038,13 @@ class BaseLLMHTTPHandler:
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
|
||||
|
||||
# All providers now return AudioTranscriptionRequestData
|
||||
if not isinstance(transformed_result, AudioTranscriptionRequestData):
|
||||
raise ValueError(f"Provider {provider_config.__class__.__name__} must return AudioTranscriptionRequestData")
|
||||
|
||||
raise ValueError(
|
||||
f"Provider {provider_config.__class__.__name__} must return AudioTranscriptionRequestData"
|
||||
)
|
||||
|
||||
data = transformed_result.data
|
||||
files = transformed_result.files
|
||||
|
||||
@@ -1143,7 +1145,9 @@ class BaseLLMHTTPHandler:
|
||||
headers=headers,
|
||||
data=data,
|
||||
files=files,
|
||||
json=data if files is None and isinstance(data, dict) else None, # Use json param only when no files and data is dict
|
||||
json=(
|
||||
data if files is None and isinstance(data, dict) else None
|
||||
), # Use json param only when no files and data is dict
|
||||
timeout=timeout,
|
||||
)
|
||||
except Exception as e:
|
||||
@@ -1214,7 +1218,9 @@ class BaseLLMHTTPHandler:
|
||||
headers=headers,
|
||||
data=data,
|
||||
files=files,
|
||||
json=data if files is None and isinstance(data, dict) else None, # Use json param only when no files and data is dict
|
||||
json=(
|
||||
data if files is None and isinstance(data, dict) else None
|
||||
), # Use json param only when no files and data is dict
|
||||
timeout=timeout,
|
||||
)
|
||||
except Exception as e:
|
||||
@@ -1432,6 +1438,7 @@ class BaseLLMHTTPHandler:
|
||||
Handles responses API requests.
|
||||
When _is_async=True, returns a coroutine instead of making the call directly.
|
||||
"""
|
||||
|
||||
if _is_async:
|
||||
# Return the async coroutine if called with _is_async=True
|
||||
return self.async_response_api_handler(
|
||||
|
||||
+36
-14
@@ -31,6 +31,7 @@ from typing import (
|
||||
Literal,
|
||||
Mapping,
|
||||
Optional,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
cast,
|
||||
@@ -822,6 +823,34 @@ def mock_completion(
|
||||
raise Exception("Mock completion response failed - {}".format(e))
|
||||
|
||||
|
||||
def responses_api_bridge_check(
|
||||
model: str,
|
||||
custom_llm_provider: str,
|
||||
) -> Tuple[dict, str]:
|
||||
model_info: Dict[str, Any] = {}
|
||||
try:
|
||||
model_info = cast(
|
||||
dict,
|
||||
_get_model_info_helper(
|
||||
model=model, custom_llm_provider=custom_llm_provider
|
||||
),
|
||||
)
|
||||
if model_info.get("mode") is None and model.startswith("responses/"):
|
||||
model = model.replace("responses/", "")
|
||||
mode = "responses"
|
||||
model_info["mode"] = mode
|
||||
except Exception as e:
|
||||
verbose_logger.debug("Error getting model info: {}".format(e))
|
||||
|
||||
if model.startswith(
|
||||
"responses/"
|
||||
): # handle azure models - `azure/responses/<deployment-name>`
|
||||
model = model.replace("responses/", "")
|
||||
mode = "responses"
|
||||
model_info["mode"] = mode
|
||||
return model_info, model
|
||||
|
||||
|
||||
@tracer.wrap()
|
||||
@client
|
||||
def completion( # type: ignore # noqa: PLR0915
|
||||
@@ -1290,19 +1319,9 @@ def completion( # type: ignore # noqa: PLR0915
|
||||
)
|
||||
|
||||
## RESPONSES API BRIDGE LOGIC ## - check if model has 'mode: responses' in litellm.model_cost map
|
||||
try:
|
||||
model_info = _get_model_info_helper(
|
||||
model=model, custom_llm_provider=custom_llm_provider
|
||||
)
|
||||
except Exception as e:
|
||||
verbose_logger.debug("Error getting model info: {}".format(e))
|
||||
model_info = {}
|
||||
if model.startswith(
|
||||
"responses/"
|
||||
): # handle azure models - `azure/responses/<deployment-name>`
|
||||
model = model.split("/")[1]
|
||||
mode = "responses"
|
||||
model_info["mode"] = mode
|
||||
model_info, model = responses_api_bridge_check(
|
||||
model=model, custom_llm_provider=custom_llm_provider
|
||||
)
|
||||
|
||||
if model_info.get("mode") == "responses":
|
||||
from litellm.completion_extras import responses_api_bridge
|
||||
@@ -4939,7 +4958,10 @@ def transcription(
|
||||
provider_config=provider_config,
|
||||
litellm_params=litellm_params_dict,
|
||||
)
|
||||
elif custom_llm_provider in [LlmProviders.DEEPGRAM.value, LlmProviders.ELEVENLABS.value]:
|
||||
elif custom_llm_provider in [
|
||||
LlmProviders.DEEPGRAM.value,
|
||||
LlmProviders.ELEVENLABS.value,
|
||||
]:
|
||||
response = base_llm_http_handler.audio_transcriptions(
|
||||
model=model,
|
||||
audio_file=file,
|
||||
|
||||
@@ -23,8 +23,9 @@ class LiteLLMPydanticObjectBase(BaseModel):
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
|
||||
|
||||
class BaseLiteLLMOpenAIResponseObject(BaseModel):
|
||||
model_config = ConfigDict(extra="allow", protected_namespaces=())
|
||||
|
||||
def __getitem__(self, key):
|
||||
return self.__dict__[key]
|
||||
|
||||
|
||||
@@ -1134,7 +1134,7 @@ class ReasoningSummaryTextDeltaEvent(BaseLiteLLMOpenAIResponseObject):
|
||||
class OutputItemAddedEvent(BaseLiteLLMOpenAIResponseObject):
|
||||
type: Literal[ResponsesAPIStreamEvents.OUTPUT_ITEM_ADDED]
|
||||
output_index: int
|
||||
item: dict
|
||||
item: Optional[dict]
|
||||
|
||||
|
||||
class OutputItemDoneEvent(BaseLiteLLMOpenAIResponseObject):
|
||||
|
||||
@@ -162,7 +162,12 @@ class ModelInfoBase(ProviderSpecificModelInfo, total=False):
|
||||
litellm_provider: Required[str]
|
||||
mode: Required[
|
||||
Literal[
|
||||
"completion", "embedding", "image_generation", "chat", "audio_transcription"
|
||||
"completion",
|
||||
"embedding",
|
||||
"image_generation",
|
||||
"chat",
|
||||
"audio_transcription",
|
||||
"responses",
|
||||
]
|
||||
]
|
||||
tpm: Optional[int]
|
||||
|
||||
@@ -0,0 +1,45 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test to verify the Google GenAI generate_content adapter functionality
|
||||
"""
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../../..")
|
||||
) # Adds the parent directory to the system path
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
|
||||
import litellm
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agenerate_content_stream():
|
||||
"""
|
||||
Test that the agenerate_content_stream function works
|
||||
"""
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
from litellm.google_genai.main import (
|
||||
agenerate_content_stream,
|
||||
base_llm_http_handler,
|
||||
)
|
||||
|
||||
with patch.object(
|
||||
base_llm_http_handler, "generate_content_handler", new=AsyncMock()
|
||||
) as mock_post:
|
||||
result = await agenerate_content_stream(
|
||||
model="gemini/gemini-2.0-flash-001",
|
||||
contents="Hello, world!",
|
||||
stream=True,
|
||||
)
|
||||
mock_post.assert_called_once()
|
||||
mock_post.call_args.kwargs["stream"] == True
|
||||
@@ -487,3 +487,34 @@ def test_bedrock_llama():
|
||||
request["raw_request_body"]["prompt"]
|
||||
== "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nhi<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
)
|
||||
|
||||
|
||||
def test_responses_api_bridge_check_strips_responses_prefix():
|
||||
"""Test that responses_api_bridge_check strips 'responses/' prefix and sets mode."""
|
||||
from litellm.main import responses_api_bridge_check
|
||||
|
||||
with patch("litellm.main._get_model_info_helper") as mock_get_model_info:
|
||||
mock_get_model_info.return_value = {"max_tokens": 4096}
|
||||
|
||||
model_info, model = responses_api_bridge_check(
|
||||
model="responses/gpt-4-responses",
|
||||
custom_llm_provider="openai",
|
||||
)
|
||||
|
||||
assert model == "gpt-4-responses"
|
||||
assert model_info["mode"] == "responses"
|
||||
|
||||
|
||||
def test_responses_api_bridge_check_handles_exception():
|
||||
"""Test that responses_api_bridge_check handles exceptions and still processes responses/ models."""
|
||||
from litellm.main import responses_api_bridge_check
|
||||
|
||||
with patch("litellm.main._get_model_info_helper") as mock_get_model_info:
|
||||
mock_get_model_info.side_effect = Exception("Model not found")
|
||||
|
||||
model_info, model = responses_api_bridge_check(
|
||||
model="responses/custom-model", custom_llm_provider="custom"
|
||||
)
|
||||
|
||||
assert model == "custom-model"
|
||||
assert model_info["mode"] == "responses"
|
||||
|
||||
@@ -602,3 +602,60 @@ def test_router_should_include_deployment():
|
||||
assert (
|
||||
result is True
|
||||
), "Should return True when matching model with exact model_name"
|
||||
|
||||
|
||||
def test_router_responses_api_bridge():
|
||||
"""
|
||||
Test that router.responses_api_bridge returns the correct response
|
||||
"""
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from litellm.llms.custom_httpx.http_handler import HTTPHandler
|
||||
|
||||
router = litellm.Router(
|
||||
model_list=[
|
||||
{
|
||||
"model_name": "[IP-approved] o3-pro",
|
||||
"litellm_params": {
|
||||
"model": "azure/responses/o_series/webinterface-o3-pro",
|
||||
"api_base": "https://webhook.site/fba79dae-220a-4bb7-9a3a-8caa49604e55",
|
||||
"api_key": "sk-1234567890",
|
||||
"api_version": "preview",
|
||||
"stream": True,
|
||||
},
|
||||
"model_info": {
|
||||
"input_cost_per_token": 0.00002,
|
||||
"output_cost_per_token": 0.00008,
|
||||
},
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
## CONFIRM BRIDGE IS CALLED
|
||||
with patch.object(litellm, "responses", return_value=AsyncMock()) as mock_responses:
|
||||
result = router.completion(
|
||||
model="[IP-approved] o3-pro",
|
||||
messages=[{"role": "user", "content": "Hello, world!"}],
|
||||
)
|
||||
assert mock_responses.call_count == 1
|
||||
|
||||
## CONFIRM MODEL NAME IS STRIPPED
|
||||
client = HTTPHandler()
|
||||
|
||||
with patch.object(client, "post", return_value=MagicMock()) as mock_post:
|
||||
try:
|
||||
result = router.completion(
|
||||
model="[IP-approved] o3-pro",
|
||||
messages=[{"role": "user", "content": "Hello, world!"}],
|
||||
client=client,
|
||||
num_retries=0,
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
|
||||
assert mock_post.call_count == 1
|
||||
assert (
|
||||
mock_post.call_args.kwargs["url"]
|
||||
== "https://webhook.site/fba79dae-220a-4bb7-9a3a-8caa49604e55/openai/v1/responses?api-version=preview"
|
||||
)
|
||||
assert mock_post.call_args.kwargs["json"]["model"] == "webinterface-o3-pro"
|
||||
|
||||
@@ -19,3 +19,19 @@ def test_generic_event():
|
||||
event = GenericEvent(**event)
|
||||
assert event.type == "test"
|
||||
assert event.test == "test"
|
||||
|
||||
|
||||
def test_output_item_added_event():
|
||||
from litellm.types.llms.openai import OutputItemAddedEvent
|
||||
|
||||
event = {
|
||||
"type": "response.output_item.added",
|
||||
"sequence_number": 4,
|
||||
"output_index": 1,
|
||||
"item": None,
|
||||
}
|
||||
event = OutputItemAddedEvent(**event)
|
||||
assert event.type == "response.output_item.added"
|
||||
assert event.sequence_number == 4
|
||||
assert event.output_index == 1
|
||||
assert event.item is None
|
||||
|
||||
Reference in New Issue
Block a user