Files
litellm/tests/test_litellm/proxy/common_utils/test_cache_codec.py
T

127 lines
5.1 KiB
Python

import logging
from typing import Optional
from unittest.mock import patch
import pytest
from pydantic import BaseModel, ValidationError
from litellm.proxy.common_utils.cache_pydantic_utils import CacheCodec
class _SampleModel(BaseModel):
    """Minimal Pydantic model used as the ``model_type`` in CacheCodec tests."""

    name: str  # required: omitting it makes validation fail (see invalid-dict tests)
    count: Optional[int] = None  # optional; tests expect it omitted from dumps when None
class _SampleSubModel(_SampleModel):
    """Subclass of ``_SampleModel`` with no extra fields.

    Used to check that CacheCodec treats instances of a subclass of
    ``model_type`` the same way as direct instances.
    """
class TestCacheCodecSerialize:
    """Tests for ``CacheCodec.serialize`` with and without a ``model_type``."""

    @staticmethod
    def _spy_on_model_validate():
        """Return a patcher that spies on ``_SampleModel.model_validate``.

        The spy wraps the real classmethod, so validation still happens while
        the spy records every call.
        """
        return patch.object(
            _SampleModel, "model_validate", wraps=_SampleModel.model_validate
        )

    # ---- no model_type ---------------------------------------------------

    def test_without_model_type_base_model_dumped_json_safe(self):
        """A bare BaseModel is dumped to a plain dict."""
        sample = _SampleModel(name="a", count=1)
        result = CacheCodec.serialize(sample)
        assert result == {"name": "a", "count": 1}

    def test_without_model_type_dict_unchanged(self):
        """A dict is returned as the very same object, not a copy."""
        payload = {"name": "x"}
        result = CacheCodec.serialize(payload)
        assert result is payload

    def test_without_model_type_primitive_unchanged(self):
        """Primitive values pass straight through."""
        result = CacheCodec.serialize(42)
        assert result == 42

    # ---- with model_type ---------------------------------------------------

    def test_with_model_type_dict_validated_and_dumped(self):
        """A dict is validated against ``model_type`` and dumped back to a dict."""
        result = CacheCodec.serialize(
            {"name": "b", "count": 2}, model_type=_SampleModel
        )
        assert result == {"name": "b", "count": 2}

    def test_with_model_type_base_model_validated_and_dumped(self):
        """A model instance is dumped without its ``None``-valued fields."""
        sample = _SampleModel(name="c", count=None)
        result = CacheCodec.serialize(sample, model_type=_SampleModel)
        assert result == {"name": "c"}

    def test_with_model_type_exclude_none_on_dump(self):
        """Optional fields left unset do not appear in the dump."""
        result = CacheCodec.serialize({"name": "d"}, model_type=_SampleModel)
        assert result == {"name": "d"}
        assert "count" not in result

    def test_with_model_type_non_dict_non_model_passthrough(self):
        """Values that are neither dicts nor models are returned unchanged."""
        result = CacheCodec.serialize("raw", model_type=_SampleModel)
        assert result == "raw"

    def test_with_model_type_invalid_dict_raises(self):
        """A dict missing a required field surfaces Pydantic's ValidationError."""
        missing_name = {"count": 1}
        with pytest.raises(ValidationError):
            CacheCodec.serialize(missing_name, model_type=_SampleModel)

    def test_with_model_type_already_correct_instance_skips_revalidation(self):
        """Fast path: an exact ``model_type`` instance never reaches model_validate."""
        sample = _SampleModel(name="fast", count=7)
        with self._spy_on_model_validate() as validate_spy:
            result = CacheCodec.serialize(sample, model_type=_SampleModel)
        assert result == {"name": "fast", "count": 7}
        validate_spy.assert_not_called()

    def test_with_model_type_subclass_instance_skips_revalidation(self):
        """A subclass instance is an ``isinstance`` of model_type, so it is fast-pathed too."""
        child = _SampleSubModel(name="sub", count=2)
        with self._spy_on_model_validate() as validate_spy:
            result = CacheCodec.serialize(child, model_type=_SampleModel)
        assert result == {"name": "sub", "count": 2}
        validate_spy.assert_not_called()

    def test_with_model_type_dict_input_goes_through_model_validate(self):
        """A dict is not yet a model instance, so model_validate runs exactly once."""
        payload = {"name": "via-dict", "count": 5}
        with self._spy_on_model_validate() as validate_spy:
            result = CacheCodec.serialize(payload, model_type=_SampleModel)
        assert result == {"name": "via-dict", "count": 5}
        validate_spy.assert_called_once()

    def test_with_model_type_incompatible_model_raises_validation_error(self):
        """A foreign BaseModel that cannot satisfy ``model_type`` makes serialize raise.

        ``_IncompatibleModel`` is not a ``_SampleModel`` (so the isinstance fast
        path exercised above does not apply), and it has no ``name`` field, which
        ``_SampleModel`` requires. Validating it against ``_SampleModel`` is
        therefore expected to raise ``ValidationError``.
        NOTE(review): whether Pydantic rejects the foreign instance outright or
        first extracts its fields depends on Pydantic's version/config and on
        CacheCodec's implementation; confirm if the mechanism matters.
        """

        class _IncompatibleModel(BaseModel):
            foo: int  # lacks the required ``name`` field of _SampleModel

        with pytest.raises(ValidationError):
            CacheCodec.serialize(_IncompatibleModel(foo=1), model_type=_SampleModel)
class TestCacheCodecDeserialize:
    """Tests for ``CacheCodec.deserialize(value, model_type)``."""

    def test_none_returns_none(self):
        """A cache miss (``None``) stays ``None``."""
        result = CacheCodec.deserialize(None, _SampleModel)
        assert result is None

    def test_dict_validates_to_model(self):
        """A valid dict is turned into a ``model_type`` instance."""
        restored = CacheCodec.deserialize({"name": "e", "count": 3}, _SampleModel)
        assert isinstance(restored, _SampleModel)
        assert restored.name == "e"
        assert restored.count == 3

    def test_instance_same_type_returned_as_is(self):
        """An existing ``model_type`` instance is returned as the same object."""
        cached = _SampleModel(name="f")
        assert CacheCodec.deserialize(cached, _SampleModel) is cached

    def test_subclass_instance_accepted(self):
        """A subclass instance is returned as the same object."""
        child = _SampleSubModel(name="g")
        assert CacheCodec.deserialize(child, _SampleModel) is child

    def test_wrong_type_returns_none(self):
        """Values that are neither dicts nor models yield ``None``."""
        result = CacheCodec.deserialize("not-a-dict", _SampleModel)
        assert result is None

    def test_invalid_dict_returns_none_and_logs_warning(self, caplog):
        """A dict that fails validation yields ``None`` and a WARNING naming the model."""
        with caplog.at_level(logging.WARNING, logger="LiteLLM Proxy"):
            result = CacheCodec.deserialize({"count": 1}, _SampleModel)
        assert result is None

        warning_texts = [
            rec.message for rec in caplog.records if rec.levelno >= logging.WARNING
        ]
        assert any(
            "CacheCodec.deserialize" in text and "_SampleModel" in text
            for text in warning_texts
        ), f"Expected deserialize validation warning. Records: {[r.message for r in caplog.records]}"