Files
litellm/tests/test_litellm/integrations/test_anthropic_cache_control_hook.py
T
2026-03-11 18:31:20 +05:30

1076 lines
40 KiB
Python

import datetime
import json
import os
import sys
import unittest
from typing import List, Optional, Tuple
from unittest.mock import ANY, MagicMock, Mock, patch
import httpx
import pytest
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system-path
import litellm
from litellm.integrations.anthropic_cache_control_hook import AnthropicCacheControlHook
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import StandardCallbackDynamicParams
@pytest.mark.asyncio
async def test_anthropic_cache_control_hook_system_message():
# Use patch.dict to mock environment variables instead of setting them directly
with patch.dict(
os.environ,
{
"AWS_ACCESS_KEY_ID": "fake_access_key_id",
"AWS_SECRET_ACCESS_KEY": "fake_secret_access_key",
"AWS_REGION_NAME": "us-west-2",
},
):
anthropic_cache_control_hook = AnthropicCacheControlHook()
litellm.callbacks = [anthropic_cache_control_hook]
# Mock response data
mock_response = MagicMock()
mock_response.json.return_value = {
"output": {
"message": {
"role": "assistant",
"content": "Here is my analysis of the key terms and conditions...",
}
},
"stopReason": "stop_sequence",
"usage": {
"inputTokens": 100,
"outputTokens": 200,
"totalTokens": 300,
"cacheReadInputTokens": 100,
"cacheWriteInputTokens": 200,
},
}
mock_response.status_code = 200
# Mock AsyncHTTPHandler.post method
client = AsyncHTTPHandler()
with patch.object(client, "post", return_value=mock_response) as mock_post:
response = await litellm.acompletion(
model="bedrock/anthropic.claude-3-5-haiku-20241022-v1:0",
messages=[
{
"role": "system",
"content": [
{
"type": "text",
"text": "You are an AI assistant tasked with analyzing legal documents.",
},
{
"type": "text",
"text": "Here is the full text of a complex legal agreement",
},
],
},
{
"role": "user",
"content": "what are the key terms and conditions in this agreement?",
},
],
cache_control_injection_points=[
{
"location": "message",
"role": "system",
},
],
client=client,
)
mock_post.assert_called_once()
request_body = json.loads(mock_post.call_args.kwargs["data"])
print("request_body: ", json.dumps(request_body, indent=4))
# Verify that cache control was applied (Bedrock transforms it to a separate item)
cache_control_count = sum(
1
for item in request_body["system"]
if isinstance(item, dict) and "cachePoint" in item
)
assert cache_control_count == 1, f"Expected exactly 1 cache control point, found {cache_control_count}"
@pytest.mark.asyncio
async def test_anthropic_cache_control_hook_user_message():
# Use patch.dict to mock environment variables instead of setting them directly
with patch.dict(
os.environ,
{
"AWS_ACCESS_KEY_ID": "fake_access_key_id",
"AWS_SECRET_ACCESS_KEY": "fake_secret_access_key",
"AWS_REGION_NAME": "us-west-2",
},
):
anthropic_cache_control_hook = AnthropicCacheControlHook()
litellm.callbacks = [anthropic_cache_control_hook]
# Mock response data
mock_response = MagicMock()
mock_response.json.return_value = {
"output": {
"message": {
"role": "assistant",
"content": "Here is my analysis of the key terms and conditions...",
}
},
"stopReason": "stop_sequence",
"usage": {
"inputTokens": 100,
"outputTokens": 200,
"totalTokens": 300,
"cacheReadInputTokens": 100,
"cacheWriteInputTokens": 200,
},
}
mock_response.status_code = 200
# Mock AsyncHTTPHandler.post method
client = AsyncHTTPHandler()
with patch.object(client, "post", return_value=mock_response) as mock_post:
response = await litellm.acompletion(
model="bedrock/anthropic.claude-3-5-haiku-20241022-v1:0",
messages=[
{
"role": "assistant",
"content": [
{
"type": "text",
"text": "You are an AI assistant tasked with analyzing legal documents.",
},
],
},
{
"role": "user",
"content": "what are the key terms and conditions in this agreement? <very_long_text>",
},
],
cache_control_injection_points=[
{
"location": "message",
"role": "user",
},
],
client=client,
)
mock_post.assert_called_once()
request_body = json.loads(mock_post.call_args.kwargs["data"])
print("request_body: ", json.dumps(request_body, indent=4))
# Verify the request body
assert request_body["messages"][1]["content"][1]["cachePoint"] == {
"type": "default"
}
@pytest.mark.asyncio
async def test_anthropic_cache_control_hook_negative_indices():
"""
Test the bug fix for handling negative indices in cache control injection points.
This test verifies that negative indices (-1, -2) are properly converted to positive indices
and cache control is applied to the correct messages.
"""
# Use patch.dict to mock environment variables instead of setting them directly
with patch.dict(
os.environ,
{
"AWS_ACCESS_KEY_ID": "fake_access_key_id",
"AWS_SECRET_ACCESS_KEY": "fake_secret_access_key",
"AWS_REGION_NAME": "us-west-2",
},
):
anthropic_cache_control_hook = AnthropicCacheControlHook()
litellm.callbacks = [anthropic_cache_control_hook]
# Mock response data
mock_response = MagicMock()
mock_response.json.return_value = {
"output": {
"message": {
"role": "assistant",
"content": "Here is my analysis of the key terms and conditions...",
}
},
"stopReason": "stop_sequence",
"usage": {
"inputTokens": 100,
"outputTokens": 200,
"totalTokens": 300,
"cacheReadInputTokens": 100,
"cacheWriteInputTokens": 200,
},
}
mock_response.status_code = 200
# Mock AsyncHTTPHandler.post method
client = AsyncHTTPHandler()
with patch.object(client, "post", return_value=mock_response) as mock_post:
# Test with multiple messages and negative indices
response = await litellm.acompletion(
model="bedrock/us.anthropic.claude-3-7-sonnet-20250219-v1:0",
messages=[
{
"role": "system",
"content": "You are an AI assistant tasked with analyzing legal documents.",
},
{
"role": "user",
"content": "Here is the first part of the document.",
},
{
"role": "assistant",
"content": "I understand. Please provide the document.",
},
{
"role": "user",
"content": "Here is the full legal document text that should be cached.",
},
],
cache_control_injection_points=[
{
"location": "message",
"index": -1, # Should target the last message (index 3)
},
{
"location": "message",
"index": -2, # Should target the second-to-last message (index 2)
},
],
client=client,
)
mock_post.assert_called_once()
request_body = json.loads(mock_post.call_args.kwargs["data"])
print("request_body: ", json.dumps(request_body, indent=4))
# The input `messages` has 4 elements. After removing the system message,
# the `request_body["messages"]` will have 3 elements (indices 0, 1, 2).
# Verify the last message (input index -1 -> request index 2) has cache control
last_message_content = request_body["messages"][2]["content"]
assert isinstance(
last_message_content, list
), "Last message content should be a list"
assert any(
"cachePoint" in item
for item in last_message_content
if isinstance(item, dict)
), "CachePoint missing in last message"
# Note: Based on debug output, the hook correctly applies cache control to both messages,
# but the Bedrock API transformation appears to only preserve cache control for user messages,
# not assistant messages. This is a limitation of the API transformation layer.
#
# The second-to-last message (assistant) gets cache_control from the hook but loses it
# during API transformation. This test documents this behavior.
second_last_message_content = request_body["messages"][1]["content"]
assert isinstance(
second_last_message_content, list
), "Second-to-last message content should be a list"
# Check if assistant message cache control is preserved (currently it's not)
assistant_has_cache_control = any(
"cachePoint" in item
for item in second_last_message_content
if isinstance(item, dict)
)
print(
f"Assistant message has cache control in final request: {assistant_has_cache_control}"
)
# Verify the first user message (request index 0) was NOT modified
first_user_message_content = request_body["messages"][0]["content"]
assert isinstance(
first_user_message_content, list
), "First user message content should be a list"
assert not any(
"cachePoint" in item
for item in first_user_message_content
if isinstance(item, dict)
), "CachePoint unexpectedly found in first user message"
@pytest.mark.asyncio
async def test_anthropic_cache_control_hook_out_of_bounds_logging():
"""
Test that warning logs are generated when out-of-bounds indices are used.
This verifies that the verbose_logger.warning is called with the correct message.
"""
with patch.dict(
os.environ,
{
"AWS_ACCESS_KEY_ID": "fake_access_key_id",
"AWS_SECRET_ACCESS_KEY": "fake_secret_access_key",
"AWS_REGION_NAME": "us-west-2",
},
):
anthropic_cache_control_hook = AnthropicCacheControlHook()
litellm.callbacks = [anthropic_cache_control_hook]
# Mock response data
mock_response = MagicMock()
mock_response.json.return_value = {
"output": {
"message": {
"role": "assistant",
"content": "Response",
}
},
"stopReason": "stop_sequence",
"usage": {
"inputTokens": 50,
"outputTokens": 100,
"totalTokens": 150,
},
}
mock_response.status_code = 200
client = AsyncHTTPHandler()
# Mock the verbose_logger to capture warning calls
with patch(
"litellm.integrations.anthropic_cache_control_hook.verbose_logger"
) as mock_logger:
with patch.object(client, "post", return_value=mock_response) as mock_post:
messages = [
{"role": "user", "content": "Message 1"},
{"role": "user", "content": "Message 2"},
]
await litellm.acompletion(
model="bedrock/us.anthropic.claude-3-7-sonnet-20250219-v1:0",
messages=messages,
cache_control_injection_points=[
{"location": "message", "index": 10}
], # Out of bounds index
client=client,
)
# Verify that warning was called with the expected message
mock_logger.warning.assert_called_once()
warning_call = mock_logger.warning.call_args[0][0]
# Check that the warning message contains the expected information
assert (
"AnthropicCacheControlHook: Provided index 10 is out of bounds"
in warning_call
)
assert "message list of length 2" in warning_call
assert "Targeted index was 10" in warning_call
assert "Skipping cache control injection for this point" in warning_call
@pytest.mark.asyncio
async def test_anthropic_cache_control_hook_negative_out_of_bounds_logging():
"""
Test that warning logs are generated for negative indices that are out of bounds.
"""
with patch.dict(
os.environ,
{
"AWS_ACCESS_KEY_ID": "fake_access_key_id",
"AWS_SECRET_ACCESS_KEY": "fake_secret_access_key",
"AWS_REGION_NAME": "us-west-2",
},
):
anthropic_cache_control_hook = AnthropicCacheControlHook()
litellm.callbacks = [anthropic_cache_control_hook]
# Mock response data
mock_response = MagicMock()
mock_response.json.return_value = {
"output": {
"message": {
"role": "assistant",
"content": "Response",
}
},
"stopReason": "stop_sequence",
"usage": {
"inputTokens": 50,
"outputTokens": 100,
"totalTokens": 150,
},
}
mock_response.status_code = 200
client = AsyncHTTPHandler()
# Mock the verbose_logger to capture warning calls
with patch(
"litellm.integrations.anthropic_cache_control_hook.verbose_logger"
) as mock_logger:
with patch.object(client, "post", return_value=mock_response) as mock_post:
messages = [
{"role": "user", "content": "Single message"},
]
await litellm.acompletion(
model="bedrock/us.anthropic.claude-3-7-sonnet-20250219-v1:0",
messages=messages,
cache_control_injection_points=[
{
"location": "message",
"index": -5,
} # Negative out of bounds index
],
client=client,
)
# Verify that warning was called with the expected message
mock_logger.warning.assert_called_once()
warning_call = mock_logger.warning.call_args[0][0]
# Check that the warning message contains the original negative index
assert (
"AnthropicCacheControlHook: Provided index -5 is out of bounds"
in warning_call
)
assert "message list of length 1" in warning_call
assert (
"Targeted index was -4" in warning_call
) # -5 + 1 = -4 (converted index)
assert "Skipping cache control injection for this point" in warning_call
@pytest.mark.asyncio
async def test_anthropic_cache_control_hook_multiple_user_messages():
"""
Test cache control injection on multiple user messages specifically.
Note: Bedrock API combines consecutive user messages into a single message with multiple content blocks.
"""
with patch.dict(
os.environ,
{
"AWS_ACCESS_KEY_ID": "fake_access_key_id",
"AWS_SECRET_ACCESS_KEY": "fake_secret_access_key",
"AWS_REGION_NAME": "us-west-2",
},
):
anthropic_cache_control_hook = AnthropicCacheControlHook()
litellm.callbacks = [anthropic_cache_control_hook]
# Mock response data
mock_response = MagicMock()
mock_response.json.return_value = {
"output": {
"message": {
"role": "assistant",
"content": "Response",
}
},
"stopReason": "stop_sequence",
"usage": {
"inputTokens": 100,
"outputTokens": 200,
"totalTokens": 300,
},
}
mock_response.status_code = 200
client = AsyncHTTPHandler()
with patch.object(client, "post", return_value=mock_response) as mock_post:
# Test with multiple user messages and negative indices
response = await litellm.acompletion(
model="bedrock/us.anthropic.claude-3-7-sonnet-20250219-v1:0",
messages=[
{
"role": "user",
"content": "First user message.",
},
{
"role": "user",
"content": "Second user message.",
},
{
"role": "user",
"content": "Third user message that should be cached.",
},
],
cache_control_injection_points=[
{
"location": "message",
"index": -1, # Should target the last message (index 2)
},
{
"location": "message",
"index": -2, # Should target the second-to-last message (index 1)
},
],
client=client,
)
mock_post.assert_called_once()
request_body = json.loads(mock_post.call_args.kwargs["data"])
print(
"Multiple user messages request_body: ",
json.dumps(request_body, indent=4),
)
# Bedrock API combines consecutive user messages into a single message
assert len(request_body["messages"]) == 1
# The combined message should have multiple content blocks with cache control
combined_message_content = request_body["messages"][0]["content"]
assert isinstance(combined_message_content, list)
# Count cache control points - should have 2 since both injection points were applied
cache_control_count = sum(
1
for item in combined_message_content
if isinstance(item, dict) and "cachePoint" in item
)
assert cache_control_count == 2
print(
f"Found {cache_control_count} cache control points in the combined message"
)
@pytest.mark.asyncio
@pytest.mark.parametrize("bad_index", [10, -10])
async def test_anthropic_cache_control_hook_out_of_bounds(bad_index):
"""
Verify the hook does not raise an error and makes no changes
when an out-of-bounds index is provided.
"""
with patch.dict(
os.environ,
{
"AWS_ACCESS_KEY_ID": "fake_access_key_id",
"AWS_SECRET_ACCESS_KEY": "fake_secret_access_key",
"AWS_REGION_NAME": "us-west-2",
},
):
anthropic_cache_control_hook = AnthropicCacheControlHook()
litellm.callbacks = [anthropic_cache_control_hook]
# Mock response data
mock_response = MagicMock()
mock_response.json.return_value = {
"output": {
"message": {
"role": "assistant",
"content": "Response",
}
},
"stopReason": "stop_sequence",
"usage": {
"inputTokens": 50,
"outputTokens": 100,
"totalTokens": 150,
},
}
mock_response.status_code = 200
client = AsyncHTTPHandler()
with patch.object(client, "post", return_value=mock_response) as mock_post:
messages = [
{"role": "user", "content": "Message 1"},
{"role": "user", "content": "Message 2"},
]
await litellm.acompletion(
model="bedrock/us.anthropic.claude-3-7-sonnet-20250219-v1:0",
messages=messages,
cache_control_injection_points=[
{"location": "message", "index": bad_index}
],
client=client,
)
mock_post.assert_called_once()
request_body = json.loads(mock_post.call_args.kwargs["data"])
# Assert that NO cache control was applied to any message
for msg in request_body["messages"]:
content = msg.get("content", [])
if isinstance(content, list):
assert not any(
"cachePoint" in item
for item in content
if isinstance(item, dict)
)
@pytest.mark.asyncio
@pytest.mark.parametrize(
"message_list",
[
[{"role": "user", "content": "Single message"}]
], # Single message only - empty list will fail at API level
)
async def test_anthropic_cache_control_hook_single_message(message_list):
"""
Verify the hook runs without error on very short message lists.
"""
with patch.dict(
os.environ,
{
"AWS_ACCESS_KEY_ID": "fake_access_key_id",
"AWS_SECRET_ACCESS_KEY": "fake_secret_access_key",
"AWS_REGION_NAME": "us-west-2",
},
):
anthropic_cache_control_hook = AnthropicCacheControlHook()
litellm.callbacks = [anthropic_cache_control_hook]
# Mock response data
mock_response = MagicMock()
mock_response.json.return_value = {
"output": {
"message": {
"role": "assistant",
"content": "Response",
}
},
"stopReason": "stop_sequence",
"usage": {
"inputTokens": 50,
"outputTokens": 100,
"totalTokens": 150,
},
}
mock_response.status_code = 200
client = AsyncHTTPHandler()
with patch.object(client, "post", return_value=mock_response) as mock_post:
await litellm.acompletion(
model="bedrock/us.anthropic.claude-3-7-sonnet-20250219-v1:0",
messages=message_list,
cache_control_injection_points=[{"location": "message", "index": -1}],
client=client,
)
mock_post.assert_called_once()
request_body = json.loads(mock_post.call_args.kwargs["data"])
# For the single message, verify cache control was applied
content = request_body["messages"][0]["content"]
assert isinstance(content, list)
assert any(
"cachePoint" in item for item in content if isinstance(item, dict)
)
@pytest.mark.asyncio
async def test_anthropic_cache_control_hook_empty_message_list():
"""
Verify that empty message lists are handled appropriately (should fail at API level, not hook level).
"""
with patch.dict(
os.environ,
{
"AWS_ACCESS_KEY_ID": "fake_access_key_id",
"AWS_SECRET_ACCESS_KEY": "fake_secret_access_key",
"AWS_REGION_NAME": "us-west-2",
},
):
anthropic_cache_control_hook = AnthropicCacheControlHook()
litellm.callbacks = [anthropic_cache_control_hook]
client = AsyncHTTPHandler()
with patch.object(client, "post", return_value=MagicMock()) as mock_post:
# This should fail at the API level, not the hook level
with pytest.raises(
litellm.BadRequestError,
match="bedrock requires at least one non-system message",
):
await litellm.acompletion(
model="bedrock/us.anthropic.claude-3-7-sonnet-20250219-v1:0",
messages=[],
cache_control_injection_points=[
{"location": "message", "index": -1}
],
client=client,
)
@pytest.mark.asyncio
async def test_anthropic_cache_control_hook_no_op():
"""
Verify that if no injection points are specified, messages remain unmodified.
"""
with patch.dict(
os.environ,
{
"AWS_ACCESS_KEY_ID": "fake_access_key_id",
"AWS_SECRET_ACCESS_KEY": "fake_secret_access_key",
"AWS_REGION_NAME": "us-west-2",
},
):
anthropic_cache_control_hook = AnthropicCacheControlHook()
litellm.callbacks = [anthropic_cache_control_hook]
# Mock response data
mock_response = MagicMock()
mock_response.json.return_value = {
"output": {
"message": {
"role": "assistant",
"content": "Response",
}
},
"stopReason": "stop_sequence",
"usage": {
"inputTokens": 50,
"outputTokens": 100,
"totalTokens": 150,
},
}
mock_response.status_code = 200
client = AsyncHTTPHandler()
with patch.object(client, "post", return_value=mock_response) as mock_post:
messages = [
{"role": "user", "content": "Message 1"},
{"role": "user", "content": "Message 2"},
]
await litellm.acompletion(
model="bedrock/us.anthropic.claude-3-7-sonnet-20250219-v1:0",
messages=messages,
# No cache_control_injection_points parameter
client=client,
)
mock_post.assert_called_once()
request_body = json.loads(mock_post.call_args.kwargs["data"])
# Assert that NO cache control was applied
for msg in request_body["messages"]:
content = msg.get("content", [])
if isinstance(content, list):
assert not any(
"cachePoint" in item
for item in content
if isinstance(item, dict)
)
@pytest.mark.asyncio
async def test_anthropic_cache_control_hook_multiple_content_items_last_only():
"""
Test that cache_control is only applied to the last content item in a list, not all items.
This verifies the fix for https://github.com/BerriAI/litellm/issues/15696
"""
with patch.dict(
os.environ,
{
"AWS_ACCESS_KEY_ID": "fake_access_key_id",
"AWS_SECRET_ACCESS_KEY": "fake_secret_access_key",
"AWS_REGION_NAME": "us-west-2",
},
):
anthropic_cache_control_hook = AnthropicCacheControlHook()
litellm.callbacks = [anthropic_cache_control_hook]
mock_response = MagicMock()
mock_response.json.return_value = {
"output": {
"message": {
"role": "assistant",
"content": "Response",
}
},
"stopReason": "stop_sequence",
"usage": {
"inputTokens": 100,
"outputTokens": 200,
"totalTokens": 300,
},
}
mock_response.status_code = 200
client = AsyncHTTPHandler()
with patch.object(client, "post", return_value=mock_response) as mock_post:
response = await litellm.acompletion(
model="bedrock/us.anthropic.claude-3-7-sonnet-20250219-v1:0",
messages=[
{
"role": "user",
"content": [
{"type": "text", "text": "First piece of context"},
{"type": "text", "text": "Second piece of context"},
{"type": "text", "text": "Third piece of context"},
{"type": "text", "text": "Fourth piece of context"},
{"type": "text", "text": "Fifth piece of context - should be cached"},
],
}
],
cache_control_injection_points=[
{"location": "message", "index": -1}
],
client=client,
)
mock_post.assert_called_once()
request_body = json.loads(mock_post.call_args.kwargs["data"])
print("Multi-content request_body: ", json.dumps(request_body, indent=4))
message_content = request_body["messages"][0]["content"]
assert isinstance(message_content, list)
cache_control_count = sum(
1
for item in message_content
if isinstance(item, dict) and "cachePoint" in item
)
assert cache_control_count == 1, f"Expected exactly 1 cache control point, found {cache_control_count}. This test verifies the fix for issue 15696 where cache_control was incorrectly applied to ALL content items."
@pytest.mark.asyncio
async def test_anthropic_cache_control_hook_document_analysis_multiple_pages():
"""
Test cache_control with multiple document pages to ensure only the last page gets cached.
This simulates document analysis with 6 content blocks, verifying the fix for issue 15696.
"""
with patch.dict(
os.environ,
{
"AWS_ACCESS_KEY_ID": "fake_access_key_id",
"AWS_SECRET_ACCESS_KEY": "fake_secret_access_key",
"AWS_REGION_NAME": "us-west-2",
},
):
anthropic_cache_control_hook = AnthropicCacheControlHook()
litellm.callbacks = [anthropic_cache_control_hook]
mock_response = MagicMock()
mock_response.json.return_value = {
"output": {
"message": {
"role": "assistant",
"content": "Summary",
}
},
"stopReason": "stop_sequence",
"usage": {
"inputTokens": 100,
"outputTokens": 200,
"totalTokens": 300,
},
}
mock_response.status_code = 200
client = AsyncHTTPHandler()
with patch.object(client, "post", return_value=mock_response) as mock_post:
response = await litellm.acompletion(
model="bedrock/us.anthropic.claude-3-7-sonnet-20250219-v1:0",
messages=[
{
"role": "user",
"content": [
{"type": "text", "text": "Summarize this document"},
{"type": "text", "text": "Page 1 content"},
{"type": "text", "text": "Page 2 content"},
{"type": "text", "text": "Page 3 content"},
{"type": "text", "text": "Page 4 content"},
{"type": "text", "text": "Page 5 content - final page to cache"},
],
}
],
cache_control_injection_points=[
{"location": "message", "role": "user"}
],
client=client,
)
mock_post.assert_called_once()
request_body = json.loads(mock_post.call_args.kwargs["data"])
print("Document analysis request_body: ", json.dumps(request_body, indent=4))
message_content = request_body["messages"][0]["content"]
assert isinstance(message_content, list)
cache_control_count = sum(
1
for item in message_content
if isinstance(item, dict) and "cachePoint" in item
)
assert cache_control_count == 1, f"Expected exactly 1 cache control point (last item only), found {cache_control_count}. Before fix, this would be 6 (one for each content item)."
def test_gemini_cache_control_injection_points_detected():
"""
Test that cache_control_injection_points work for Gemini models.
Verifies the full flow:
1. The hook injects cache_control markers on string-content messages
2. is_cached_message() detects the injected markers (message-level cache_control)
3. separate_cached_messages() correctly separates the messages
Fixes GitHub issue #18519.
"""
from litellm.llms.vertex_ai.context_caching.transformation import (
separate_cached_messages,
)
from litellm.utils import is_cached_message
hook = AnthropicCacheControlHook()
# Simulate messages as they would appear for a Gemini call with string content
messages: List[AllMessageValues] = [
{
"role": "system",
"content": "You are a helpful assistant that analyzes legal documents.",
},
{
"role": "user",
"content": "What are the key terms?",
},
]
# Simulate what the hook does: inject cache_control on the system message
injection_points = [{"location": "message", "role": "system"}]
# Manually apply the hook's logic for the system message (string content case)
# The hook sets message["cache_control"] = {"type": "ephemeral"} for string content
hook._safe_insert_cache_control_in_message(
message=messages[0],
control={"type": "ephemeral"},
)
# Verify the hook injected message-level cache_control (string content path)
assert messages[0].get("cache_control") == {"type": "ephemeral"}
# Verify is_cached_message detects message-level cache_control
assert is_cached_message(messages[0]) is True
assert is_cached_message(messages[1]) is False
# Verify separate_cached_messages correctly separates them
cached, non_cached = separate_cached_messages(messages)
assert len(cached) == 1
assert cached[0]["role"] == "system"
assert len(non_cached) == 1
assert non_cached[0]["role"] == "user"
def test_gemini_cache_control_injection_list_content_detected():
"""
Test that cache_control_injection_points work for Gemini models
when the message content is a list (not string).
"""
from litellm.llms.vertex_ai.context_caching.transformation import (
separate_cached_messages,
)
from litellm.utils import is_cached_message
hook = AnthropicCacheControlHook()
messages: List[AllMessageValues] = [
{
"role": "system",
"content": [
{"type": "text", "text": "You are a helpful assistant."},
{"type": "text", "text": "Analyze legal documents carefully."},
],
},
{
"role": "user",
"content": "What are the key terms?",
},
]
# Apply the hook's logic for list content - sets cache_control on last item
hook._safe_insert_cache_control_in_message(
message=messages[0],
control={"type": "ephemeral"},
)
# Verify cache_control was set on the last content item
assert messages[0]["content"][-1]["cache_control"] == {"type": "ephemeral"}
# Verify is_cached_message detects content-item-level cache_control
assert is_cached_message(messages[0]) is True
assert is_cached_message(messages[1]) is False
# Verify separate_cached_messages correctly separates them
cached, non_cached = separate_cached_messages(messages)
assert len(cached) == 1
assert len(non_cached) == 1
@pytest.mark.asyncio
async def test_anthropic_cache_control_hook_string_negative_index():
"""
Test that string negative indices like "-1" are handled correctly.
When cache_control_injection_points are stored in DB/config as JSON, indices
like -1 become the string "-1". Previously, str.isdigit() returned False for
"-1" so the cache control was silently skipped. This tests the fix.
"""
with patch.dict(
os.environ,
{
"AWS_ACCESS_KEY_ID": "fake_access_key_id",
"AWS_SECRET_ACCESS_KEY": "fake_secret_access_key",
"AWS_REGION_NAME": "us-west-2",
},
):
anthropic_cache_control_hook = AnthropicCacheControlHook()
litellm.callbacks = [anthropic_cache_control_hook]
mock_response = MagicMock()
mock_response.json.return_value = {
"output": {
"message": {
"role": "assistant",
"content": "Response",
}
},
"stopReason": "end_turn",
"usage": {
"inputTokens": 100,
"outputTokens": 50,
"totalTokens": 150,
},
}
mock_response.status_code = 200
client = AsyncHTTPHandler()
with patch.object(client, "post", return_value=mock_response) as mock_post:
await litellm.acompletion(
model="bedrock/us.anthropic.claude-3-7-sonnet-20250219-v1:0",
messages=[
{"role": "user", "content": "First message"},
{"role": "assistant", "content": "First response"},
{"role": "user", "content": "Second message"},
],
# index is a string "-1" (as stored in DB/config JSON)
cache_control_injection_points=[
{"location": "message", "index": "-1"},
],
client=client,
)
mock_post.assert_called_once()
request_body = json.loads(mock_post.call_args.kwargs["data"])
# The last user message should have cache control applied
last_message = request_body["messages"][-1]
last_message_content = last_message["content"]
assert isinstance(last_message_content, list), (
f"Expected list content, got {type(last_message_content)}"
)
has_cache_point = any(
isinstance(item, dict) and "cachePoint" in item
for item in last_message_content
)
assert has_cache_point, (
f"Expected cachePoint in last message content, got: {last_message_content}. "
"String index '-1' was not parsed correctly (str.isdigit() returns False for negative strings)."
)