feat(gemini): Add custom TTL support for context caching (#9810) (#12541)

- Add ttl parameter to cache_control for Gemini models
- Support Google's TTL format (e.g., '3600s', '7200s')
- Implement robust TTL extraction and validation
- Extract TTL before system message transformation to handle all cases
- Add comprehensive test suite with 17 test cases in tests/test_litellm/
- Update documentation with TTL usage examples
- Maintain backward compatibility with existing cache_control usage

Fixes #9810
This commit is contained in:
Marcelo Díaz
2025-07-15 01:30:54 -04:00
committed by GitHub
parent f05ec34e11
commit 094ce8f772
4 changed files with 727 additions and 5 deletions
+142 -4
View File
@@ -1219,12 +1219,38 @@ Use Google AI Studio context caching is supported by
in your message content block.
### Custom TTL Support
You can now specify a custom Time-To-Live (TTL) for your cached content using the `ttl` parameter:
```bash
[
    {
        "role": "system",
        "content": ...,
        "cache_control": {
            "type": "ephemeral",
            "ttl": "3600s" # 👈 Cache for 1 hour
        }
    },
    ...
]
```
**TTL Format Requirements:**
- Must be a string ending with 's' for seconds
- Must contain a positive number (can be decimal)
- Examples: `"3600s"` (1 hour), `"7200s"` (2 hours), `"1800s"` (30 minutes), `"1.5s"` (1.5 seconds)
**TTL Behavior:**
- If multiple cached messages have different TTLs, the first valid TTL encountered will be used
- Invalid TTL formats are ignored and the cache will use Google's default expiration time
- If no TTL is specified, Google's default cache expiration (approximately 1 hour) applies
### Architecture Diagram
<Image img={require('../../img/gemini_context_caching.png')} />
**Notes:**
- [Relevant code](https://github.com/BerriAI/litellm/blob/main/litellm/llms/vertex_ai/context_caching/vertex_ai_context_caching.py#L255)
@@ -1233,7 +1259,6 @@ in your message content block.
- If multiple non-continuous blocks contain `cache_control` - the first continuous block will be used. (sent to `/cachedContent` in the [Gemini format](https://ai.google.dev/api/caching#cache_create-SHELL))
- The raw request to Gemini's `/generateContent` endpoint looks like this:
```bash
@@ -1253,7 +1278,6 @@ curl -X POST "https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5
```
### Example Usage
<Tabs>
@@ -1293,6 +1317,48 @@ for _ in range(2):
print(resp.usage) # 👈 2nd usage block will be less, since cached tokens used
```
</TabItem>
<TabItem value="sdk-ttl" label="SDK with Custom TTL">
```python
from litellm import completion
# Cache for 2 hours (7200 seconds)
resp = completion(
model="gemini/gemini-1.5-pro",
messages=[
{
"role": "system",
"content": [
{
"type": "text",
"text": "Here is the full text of a complex legal agreement" * 4000,
"cache_control": {
"type": "ephemeral",
"ttl": "7200s" # 👈 Cache for 2 hours
},
}
],
},
{
"role": "user",
"content": [
{
"type": "text",
"text": "What are the key terms and conditions in this agreement?",
"cache_control": {
"type": "ephemeral",
"ttl": "3600s" # 👈 This TTL will be ignored (first one is used)
},
}
],
}
]
)
print(resp.usage)
```
</TabItem>
<TabItem value="proxy" label="PROXY">
@@ -1350,6 +1416,44 @@ curl --location 'http://0.0.0.0:4000/chat/completions' \
}'
```
</TabItem>
<TabItem value="curl-ttl" label="Curl with Custom TTL">
```bash
curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Content-Type: application/json' \
--data '{
"model": "gemini-1.5-pro",
"messages": [
{
"role": "system",
"content": [
{
"type": "text",
"text": "<long document text, e.g. the full text of a complex legal agreement>",
"cache_control": {
"type": "ephemeral",
"ttl": "7200s"
}
}
]
},
{
"role": "user",
"content": [
{
"type": "text",
"text": "What are the key terms and conditions in this agreement?",
"cache_control": {
"type": "ephemeral",
"ttl": "3600s"
}
}
]
}
]
}'
```
</TabItem>
<TabItem value="openai-python" label="OpenAI Python SDK">
```python
@@ -1382,6 +1486,40 @@ response = await client.chat.completions.create(
```
</TabItem>
<TabItem value="openai-python-ttl" label="OpenAI Python SDK with TTL">
```python
import openai
client = openai.AsyncOpenAI(
api_key="anything", # litellm proxy api key
base_url="http://0.0.0.0:4000" # litellm proxy base url
)
response = await client.chat.completions.create(
model="gemini-1.5-pro",
messages=[
{
"role": "system",
"content": [
{
"type": "text",
"text": "Here is the full text of a complex legal agreement" * 4000,
"cache_control": {
"type": "ephemeral",
"ttl": "7200s" # Cache for 2 hours
}
}
],
},
{
"role": "user",
"content": "what are the key terms and conditions in this agreement?",
},
]
)
```
</TabItem>
</Tabs>
@@ -4,7 +4,8 @@ Transformation logic for context caching.
Why separate file? Make it easy to see how transformation works
"""
from typing import List, Tuple
import re
from typing import List, Optional, Tuple
from litellm.types.llms.openai import AllMessageValues
from litellm.types.llms.vertex_ai import CachedContentRequestBody
@@ -47,6 +48,72 @@ def get_first_continuous_block_idx(
return len(filtered_messages) - 1
def extract_ttl_from_cached_messages(messages: List[AllMessageValues]) -> Optional[str]:
    """
    Find the TTL to apply to cached content.

    Messages are scanned in order, and within each cached message the content
    blocks are scanned in order. The first ``cache_control`` of type
    "ephemeral" whose ``ttl`` passes ``_is_valid_ttl_format`` wins.

    Args:
        messages: Chat messages to scan.

    Returns:
        Optional[str]: The first valid TTL (e.g. "3600s"), or None when no
        cached content block carries a valid TTL.
    """
    for msg in messages:
        if not is_cached_message(msg):
            continue

        parts = msg.get("content")
        # Plain-string, empty or missing content has no per-block cache_control.
        if isinstance(parts, str) or not parts:
            continue

        for part in parts:
            # Only dict blocks can be queried with .get().
            if not isinstance(part, dict):
                continue

            control = part.get("cache_control")
            if not (control and isinstance(control, dict)):
                continue
            if control.get("type") != "ephemeral":
                continue

            candidate = control.get("ttl")
            if candidate and _is_valid_ttl_format(candidate):
                return str(candidate)

    return None
def _is_valid_ttl_format(ttl: str) -> bool:
    """
    Validate a TTL string: a positive number of seconds followed by 's'.

    Examples of accepted values: "3600s", "7200s", "1.5s"

    Args:
        ttl: TTL value to validate. Non-string values are rejected.

    Returns:
        bool: True if the whole value is a positive decimal number followed
        by a single 's', False otherwise.
    """
    if not isinstance(ttl, str):
        return False

    # fullmatch (instead of match with a trailing '$') so that a value such as
    # "3600s\n" is rejected: '$' also matches just before a final newline.
    match = re.fullmatch(r"([0-9]*\.?[0-9]+)s", ttl)
    if match is None:
        return False

    # The pattern guarantees the captured text parses as a float, so only the
    # "strictly positive" check remains (rejects "0s", "0.0s").
    return float(match.group(1)) > 0
def separate_cached_messages(
messages: List[AllMessageValues],
) -> Tuple[List[AllMessageValues], List[AllMessageValues]]:
@@ -90,6 +157,9 @@ def separate_cached_messages(
def transform_openai_messages_to_gemini_context_caching(
model: str, messages: List[AllMessageValues], cache_key: str
) -> CachedContentRequestBody:
# Extract TTL from cached messages BEFORE system message transformation
ttl = extract_ttl_from_cached_messages(messages)
supports_system_message = get_supports_system_message(
model=model, custom_llm_provider="gemini"
)
@@ -99,11 +169,17 @@ def transform_openai_messages_to_gemini_context_caching(
)
transformed_messages = _gemini_convert_messages_with_history(messages=new_messages)
data = CachedContentRequestBody(
contents=transformed_messages,
model="models/{}".format(model),
displayName=cache_key,
)
# Add TTL if present and valid
if ttl:
data["ttl"] = ttl
if transformed_system_messages is not None:
data["system_instruction"] = transformed_system_messages
+148
View File
@@ -12,6 +12,7 @@ sys.path.insert(
from base_llm_unit_tests import BaseLLMChatTest
from litellm.llms.vertex_ai.context_caching.transformation import (
separate_cached_messages,
transform_openai_messages_to_gemini_context_caching,
)
import litellm
from litellm import completion
@@ -67,6 +68,153 @@ class TestGoogleAIStudioGemini(BaseLLMChatTest):
print(f"response={response}")
def test_gemini_context_caching_with_ttl():
    """Check how cache_control TTLs surface in the context-caching request body."""

    def cached_text(text, ttl=None):
        # One text block flagged as ephemeral cache content; "ttl" is only
        # added to cache_control when the scenario supplies one.
        control = {"type": "ephemeral"}
        if ttl is not None:
            control["ttl"] = ttl
        return {"type": "text", "text": text, "cache_control": control}

    def build_request(messages, cache_key):
        return transform_openai_messages_to_gemini_context_caching(
            model="gemini-1.5-pro",
            messages=messages,
            cache_key=cache_key,
        )

    # Scenario 1: system and user blocks both carry a TTL; the first one is used.
    messages_with_ttl = [
        {
            "role": "system",
            "content": [
                cached_text(
                    "Here is the full text of a complex legal agreement" * 400,
                    "3600s",
                )
            ],
        },
        {
            "role": "user",
            "content": [
                cached_text(
                    "What are the key terms and conditions in this agreement?",
                    "7200s",
                )
            ],
        },
    ]
    result = build_request(messages_with_ttl, "test-ttl-cache-key")
    assert "ttl" in result
    assert result["ttl"] == "3600s"  # Should use the first valid TTL found
    assert result["model"] == "models/gemini-1.5-pro"
    assert result["displayName"] == "test-ttl-cache-key"

    # Scenario 2: a malformed TTL must not be forwarded.
    messages_invalid_ttl = [
        {
            "role": "user",
            "content": [
                cached_text("Cached content with invalid TTL", "invalid_ttl")
            ],
        }
    ]
    result_invalid = build_request(messages_invalid_ttl, "test-invalid-ttl")
    assert "ttl" not in result_invalid
    assert result_invalid["model"] == "models/gemini-1.5-pro"
    assert result_invalid["displayName"] == "test-invalid-ttl"

    # Scenario 3: cache_control without a TTL produces no "ttl" field.
    messages_no_ttl = [
        {
            "role": "user",
            "content": [cached_text("Cached content without TTL")],
        }
    ]
    result_no_ttl = build_request(messages_no_ttl, "test-no-ttl")
    assert "ttl" not in result_no_ttl
    assert result_no_ttl["model"] == "models/gemini-1.5-pro"
    assert result_no_ttl["displayName"] == "test-no-ttl"

    # Scenario 4: a conversation mixing TTL, no-TTL and uncached messages.
    messages_mixed = [
        {
            "role": "system",
            "content": [cached_text("System message with TTL", "1800s")],
        },
        {
            "role": "user",
            "content": [cached_text("User message without TTL")],
        },
        {
            "role": "assistant",
            "content": "Assistant response without cache control",
        },
        {
            "role": "user",
            "content": [cached_text("Another user message", "900s")],
        },
    ]

    # Both partitions must be non-empty for this conversation.
    cached_messages, non_cached_messages = separate_cached_messages(messages_mixed)
    assert len(cached_messages) > 0
    assert len(non_cached_messages) > 0

    result_mixed = build_request(messages_mixed, "test-mixed-ttl")
    assert "ttl" in result_mixed
    assert result_mixed["ttl"] == "1800s"  # first valid TTL in message order
    assert result_mixed["model"] == "models/gemini-1.5-pro"
    assert result_mixed["displayName"] == "test-mixed-ttl"
def test_gemini_context_caching_separate_messages():
messages = [
# System Message
@@ -0,0 +1,360 @@
import pytest
from litellm.llms.vertex_ai.context_caching.transformation import (
extract_ttl_from_cached_messages,
_is_valid_ttl_format,
transform_openai_messages_to_gemini_context_caching,
)
class TestTTLValidation:
    """Accept/reject checks for the TTL string validator."""

    def test_valid_ttl_formats(self):
        """Positive second counts, integer or decimal, are accepted."""
        accepted = (
            "3600s",
            "1s",
            "7200s",
            "1.5s",
            "0.1s",
            "86400s",
            "123.456s",
        )
        for ttl in accepted:
            assert _is_valid_ttl_format(ttl), f"TTL {ttl} should be valid"

    def test_invalid_ttl_formats(self):
        """Malformed, non-positive and non-string values are rejected."""
        rejected = (
            "3600",  # no unit suffix
            "s",  # unit without a number
            "-1s",  # negative
            "0s",  # zero is not positive
            "3600m",  # unsupported unit
            "abc.s",  # not numeric
            "",  # empty
            "3600.s",  # dangling decimal point
            "3600 s",  # embedded space
            "3600ss",  # doubled suffix
            None,  # not a string
            123,  # not a string
        )
        for ttl in rejected:
            assert not _is_valid_ttl_format(ttl), f"TTL {ttl} should be invalid"
class TestTTLExtraction:
    """Tests for reading a TTL out of cache_control-annotated messages."""

    @staticmethod
    def _message(role, text, cache_control=None):
        # Build a message whose content is a single text block; the block gets
        # a cache_control entry only when one is supplied.
        block = {"type": "text", "text": text}
        if cache_control is not None:
            block["cache_control"] = cache_control
        return {"role": role, "content": [block]}

    def test_extract_ttl_from_single_message(self):
        """A single cached message yields its TTL."""
        messages = [
            self._message(
                "user",
                "This is cached content",
                {"type": "ephemeral", "ttl": "3600s"},
            )
        ]
        assert extract_ttl_from_cached_messages(messages) == "3600s"

    def test_extract_ttl_from_multiple_messages(self):
        """With several cached messages, the earliest valid TTL is returned."""
        messages = [
            self._message(
                "system", "System message", {"type": "ephemeral", "ttl": "7200s"}
            ),
            self._message(
                "user", "User message", {"type": "ephemeral", "ttl": "3600s"}
            ),
        ]
        found = extract_ttl_from_cached_messages(messages)
        assert found == "7200s"  # first valid TTL in message order

    def test_extract_ttl_no_cache_control(self):
        """Messages without cache_control yield no TTL."""
        messages = [self._message("user", "Regular message without cache control")]
        assert extract_ttl_from_cached_messages(messages) is None

    def test_extract_ttl_invalid_format(self):
        """A malformed TTL is treated as absent."""
        messages = [
            self._message(
                "user",
                "Cached content with invalid TTL",
                {"type": "ephemeral", "ttl": "invalid"},
            )
        ]
        assert extract_ttl_from_cached_messages(messages) is None

    def test_extract_ttl_missing_ttl_field(self):
        """cache_control without a ttl key yields no TTL."""
        messages = [
            self._message(
                "user", "Cached content without TTL field", {"type": "ephemeral"}
            )
        ]
        assert extract_ttl_from_cached_messages(messages) is None

    def test_extract_ttl_mixed_valid_invalid(self):
        """An invalid TTL is skipped in favour of a later valid one."""
        messages = [
            self._message(
                "system",
                "System message with invalid TTL",
                {"type": "ephemeral", "ttl": "invalid"},
            ),
            self._message(
                "user",
                "User message with valid TTL",
                {"type": "ephemeral", "ttl": "3600s"},
            ),
        ]
        found = extract_ttl_from_cached_messages(messages)
        assert found == "3600s"  # the invalid system TTL is passed over

    def test_extract_ttl_string_content(self):
        """Plain-string content (not a list of blocks) yields no TTL."""
        messages = [{"role": "user", "content": "String content"}]
        assert extract_ttl_from_cached_messages(messages) is None
class TestTransformationWithTTL:
    """End-to-end checks of the request-body transformation with TTLs."""

    @staticmethod
    def _message(role, text, cache_control=None):
        # Build a message whose content is a single text block; the block gets
        # a cache_control entry only when one is supplied.
        block = {"type": "text", "text": text}
        if cache_control is not None:
            block["cache_control"] = cache_control
        return {"role": role, "content": [block]}

    @staticmethod
    def _transform(messages):
        # Every scenario in this class uses the same model and cache key.
        return transform_openai_messages_to_gemini_context_caching(
            model="gemini-1.5-pro",
            messages=messages,
            cache_key="test-cache-key",
        )

    def test_transform_with_valid_ttl(self):
        """A valid TTL is copied into the request body."""
        result = self._transform(
            [
                self._message(
                    "user", "Cached content", {"type": "ephemeral", "ttl": "3600s"}
                )
            ]
        )

        assert "ttl" in result
        assert result["ttl"] == "3600s"
        assert result["model"] == "models/gemini-1.5-pro"
        assert result["displayName"] == "test-cache-key"

    def test_transform_without_ttl(self):
        """Without a TTL the request body has no "ttl" key."""
        result = self._transform(
            [self._message("user", "Cached content", {"type": "ephemeral"})]
        )

        assert "ttl" not in result
        assert result["model"] == "models/gemini-1.5-pro"
        assert result["displayName"] == "test-cache-key"

    def test_transform_with_invalid_ttl(self):
        """A malformed TTL is dropped rather than forwarded."""
        result = self._transform(
            [
                self._message(
                    "user", "Cached content", {"type": "ephemeral", "ttl": "invalid"}
                )
            ]
        )

        assert "ttl" not in result
        assert result["model"] == "models/gemini-1.5-pro"
        assert result["displayName"] == "test-cache-key"

    def test_transform_with_system_message_and_ttl(self):
        """A TTL on a system message is kept alongside system_instruction."""
        result = self._transform(
            [
                self._message(
                    "system",
                    "System instruction",
                    {"type": "ephemeral", "ttl": "7200s"},
                ),
                self._message("user", "User message"),
            ]
        )

        assert "ttl" in result
        assert result["ttl"] == "7200s"
        assert "system_instruction" in result
        assert result["model"] == "models/gemini-1.5-pro"
        assert result["displayName"] == "test-cache-key"
class TestEdgeCases:
    """Test edge cases and error conditions"""

    def test_ttl_extraction_empty_messages(self):
        """Test TTL extraction with empty message list"""
        messages = []
        ttl = extract_ttl_from_cached_messages(messages)
        assert ttl is None

    def test_ttl_extraction_none_content(self):
        """Test TTL extraction when content is None"""
        messages = [
            {
                "role": "user",
                "content": None,
            }
        ]
        ttl = extract_ttl_from_cached_messages(messages)
        assert ttl is None

    def test_ttl_extraction_empty_content_list(self):
        """Test TTL extraction when content list is empty"""
        messages = [
            {
                "role": "user",
                "content": [],
            }
        ]
        ttl = extract_ttl_from_cached_messages(messages)
        assert ttl is None

    def test_ttl_validation_type_conversion(self):
        """Test that a string TTL is returned as str and a numeric TTL is ignored"""

        def cached_message(ttl_value):
            return [
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": "Cached content",
                            "cache_control": {"type": "ephemeral", "ttl": ttl_value},
                        }
                    ],
                }
            ]

        # A well-formed string TTL comes back unchanged, as a str.
        ttl = extract_ttl_from_cached_messages(cached_message("3600s"))
        assert isinstance(ttl, str)
        assert ttl == "3600s"

        # A numeric TTL is not coerced to "3600s": validation only accepts
        # strings, so the value is ignored.
        assert extract_ttl_from_cached_messages(cached_message(3600)) is None
# Allow running this test module directly; pytest is invoked in verbose mode.
if __name__ == "__main__":
    pytest.main([__file__, "-v"])