mirror of
https://github.com/tiennm99/litellm.git
synced 2026-06-18 03:31:23 +00:00
Normalize OpenAI SDK BaseModel choices/messages to avoid Pydantic serializer warnings (#18972)
* Normalize BaseModel choices + suppress serializer warnings * Fix ModelResponse normalization and test deps
This commit is contained in:
@@ -35,6 +35,7 @@ jobs:
|
||||
poetry run pip install "google-cloud-aiplatform>=1.38"
|
||||
poetry run pip install "fastapi-offline==1.7.3"
|
||||
poetry run pip install "python-multipart==0.0.18"
|
||||
poetry run pip install "openapi-core"
|
||||
- name: Setup litellm-enterprise as local package
|
||||
run: |
|
||||
cd enterprise
|
||||
|
||||
@@ -45,6 +45,7 @@ install-proxy-dev-ci:
|
||||
install-test-deps: install-proxy-dev
|
||||
poetry run pip install "pytest-retry==1.6.3"
|
||||
poetry run pip install pytest-xdist
|
||||
poetry run pip install openapi-core
|
||||
cd enterprise && poetry run pip install -e . && cd ..
|
||||
|
||||
install-helm-unittest:
|
||||
@@ -100,4 +101,4 @@ test-llm-translation-single: install-test-deps
|
||||
@mkdir -p test-results
|
||||
poetry run pytest tests/llm_translation/$(FILE) \
|
||||
--junitxml=test-results/junit.xml \
|
||||
-v --tb=short --maxfail=100 --timeout=300
|
||||
-v --tb=short --maxfail=100 --timeout=300
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
# Package marker for enterprise proxy components.
|
||||
@@ -0,0 +1 @@
|
||||
# Package marker for enterprise proxy common utilities.
|
||||
+36
-7
@@ -1250,6 +1250,14 @@ class Choices(OpenAIObject):
|
||||
params["message"] = message
|
||||
elif isinstance(message, dict):
|
||||
params["message"] = Message(**message)
|
||||
elif isinstance(message, BaseModel):
|
||||
# Normalize provider/OpenAI SDK message models into LiteLLM's Message type.
|
||||
dump = (
|
||||
message.model_dump()
|
||||
if hasattr(message, "model_dump")
|
||||
else message.dict()
|
||||
)
|
||||
params["message"] = Message(**dump)
|
||||
if logprobs is not None:
|
||||
if isinstance(logprobs, dict):
|
||||
params["logprobs"] = ChoiceLogprobs(**logprobs)
|
||||
@@ -1612,6 +1620,12 @@ class ModelResponseBase(OpenAIObject):
|
||||
|
||||
_response_headers: Optional[dict] = None
|
||||
|
||||
def model_dump(self, **kwargs):
|
||||
"""Default to exclude_unset to avoid Pydantic serializer warnings for OpenAIObject-derived types."""
|
||||
if "exclude_unset" not in kwargs and "exclude_none" not in kwargs:
|
||||
kwargs["exclude_unset"] = True
|
||||
return super().model_dump(**kwargs)
|
||||
|
||||
|
||||
class ModelResponseStream(ModelResponseBase):
|
||||
choices: List[StreamingChoices]
|
||||
@@ -1651,12 +1665,16 @@ class ModelResponseStream(ModelResponseBase):
|
||||
else:
|
||||
created = created
|
||||
|
||||
if (
|
||||
"usage" in kwargs
|
||||
and kwargs["usage"] is not None
|
||||
and isinstance(kwargs["usage"], dict)
|
||||
):
|
||||
kwargs["usage"] = Usage(**kwargs["usage"])
|
||||
if "usage" in kwargs and kwargs["usage"] is not None:
|
||||
if isinstance(kwargs["usage"], dict):
|
||||
kwargs["usage"] = Usage(**kwargs["usage"])
|
||||
elif isinstance(kwargs["usage"], BaseModel):
|
||||
dump = (
|
||||
kwargs["usage"].model_dump()
|
||||
if hasattr(kwargs["usage"], "model_dump")
|
||||
else kwargs["usage"].dict()
|
||||
)
|
||||
kwargs["usage"] = Usage(**dump)
|
||||
|
||||
kwargs["id"] = id
|
||||
kwargs["created"] = created
|
||||
@@ -1730,6 +1748,13 @@ class ModelResponse(ModelResponseBase):
|
||||
_new_choice = choice # type: ignore
|
||||
elif isinstance(choice, dict):
|
||||
_new_choice = Choices(**choice) # type: ignore
|
||||
elif isinstance(choice, BaseModel):
|
||||
dump = (
|
||||
choice.model_dump()
|
||||
if hasattr(choice, "model_dump")
|
||||
else choice.dict()
|
||||
)
|
||||
_new_choice = Choices(**dump) # type: ignore
|
||||
else:
|
||||
_new_choice = choice
|
||||
new_choices.append(_new_choice)
|
||||
@@ -1748,6 +1773,11 @@ class ModelResponse(ModelResponseBase):
|
||||
if usage is not None:
|
||||
if isinstance(usage, dict):
|
||||
usage = Usage(**usage)
|
||||
elif isinstance(usage, BaseModel):
|
||||
dump = (
|
||||
usage.model_dump() if hasattr(usage, "model_dump") else usage.dict()
|
||||
)
|
||||
usage = Usage(**dump)
|
||||
else:
|
||||
usage = usage
|
||||
elif stream is None or stream is False:
|
||||
@@ -3032,7 +3062,6 @@ class LlmProviders(str, Enum):
|
||||
XIAOMI_MIMO = "xiaomi_mimo"
|
||||
|
||||
|
||||
|
||||
# Create a set of all provider values for quick lookup
|
||||
LlmProvidersSet = {provider.value for provider in LlmProviders}
|
||||
|
||||
|
||||
@@ -0,0 +1,61 @@
|
||||
import warnings
|
||||
|
||||
import pytest
|
||||
|
||||
from litellm.types.utils import Choices, Message, ModelResponse
|
||||
|
||||
|
||||
def test_modelresponse_normalizes_openai_base_models() -> None:
|
||||
# OpenAI SDK returns Pydantic BaseModel objects for message/choice.
|
||||
# LiteLLM should normalize these into its own internal `Message` / `Choices` types.
|
||||
from openai.types.chat.chat_completion import Choice as OpenAIChoice
|
||||
from openai.types.chat.chat_completion_message import ChatCompletionMessage
|
||||
|
||||
message = ChatCompletionMessage(role="assistant", content="hi")
|
||||
choice = OpenAIChoice(finish_reason="stop", index=0, message=message, logprobs=None)
|
||||
|
||||
with warnings.catch_warnings(record=True) as captured:
|
||||
warnings.simplefilter("always")
|
||||
response = ModelResponse(model="gpt-4o-mini", choices=[choice])
|
||||
_ = response.model_dump()
|
||||
|
||||
assert isinstance(response.choices[0], Choices)
|
||||
assert isinstance(response.choices[0].message, Message)
|
||||
|
||||
assert not any(
|
||||
"Pydantic serializer warnings" in str(w.message)
|
||||
for w in captured
|
||||
if isinstance(w.message, Warning)
|
||||
)
|
||||
|
||||
|
||||
def test_modelresponse_serialization_avoids_pydantic_warnings() -> None:
|
||||
pytest.importorskip("openai")
|
||||
from openai.types.chat import ChatCompletion as OpenAIChatCompletion
|
||||
|
||||
openai_completion = OpenAIChatCompletion(
|
||||
id="test-1",
|
||||
created=1719868600,
|
||||
model="gpt-4o-mini",
|
||||
object="chat.completion",
|
||||
choices=[
|
||||
{
|
||||
"index": 0,
|
||||
"finish_reason": "stop",
|
||||
"message": {"role": "assistant", "content": "hi"},
|
||||
"logprobs": None,
|
||||
}
|
||||
],
|
||||
usage={"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2},
|
||||
)
|
||||
|
||||
with warnings.catch_warnings(record=True) as captured:
|
||||
warnings.simplefilter("always")
|
||||
response = ModelResponse(**openai_completion.model_dump())
|
||||
_ = response.model_dump(exclude_none=True)
|
||||
|
||||
assert not any(
|
||||
"PydanticSerializationUnexpectedValue" in str(w.message)
|
||||
or "Pydantic serializer warnings" in str(w.message)
|
||||
for w in captured
|
||||
)
|
||||
Reference in New Issue
Block a user