diff --git a/docs/my-website/docs/providers/gemini.md b/docs/my-website/docs/providers/gemini.md
index 0d388a4151..9376144cc8 100644
--- a/docs/my-website/docs/providers/gemini.md
+++ b/docs/my-website/docs/providers/gemini.md
@@ -1219,12 +1219,38 @@ Use Google AI Studio context caching is supported by
in your message content block.
+### Custom TTL Support
+
+You can now specify a custom Time-To-Live (TTL) for your cached content using the `ttl` parameter:
+
+```bash
+{
+ {
+ "role": "system",
+ "content": ...,
+ "cache_control": {
+ "type": "ephemeral",
+ "ttl": "3600s" # 👈 Cache for 1 hour
+ }
+ },
+ ...
+}
+```
+
+**TTL Format Requirements:**
+- Must be a string ending with 's' for seconds
+- Must contain a positive number (can be decimal)
+- Examples: `"3600s"` (1 hour), `"7200s"` (2 hours), `"1800s"` (30 minutes), `"1.5s"` (1.5 seconds)
+
+**TTL Behavior:**
+- If multiple cached messages have different TTLs, the first valid TTL encountered will be used
+- Invalid TTL formats are ignored and the cache will use Google's default expiration time
+- If no TTL is specified, Google's default cache expiration (approximately 1 hour) applies
+
### Architecture Diagram
-
-
**Notes:**
- [Relevant code](https://github.com/BerriAI/litellm/blob/main/litellm/llms/vertex_ai/context_caching/vertex_ai_context_caching.py#L255)
@@ -1233,7 +1259,6 @@ in your message content block.
- If multiple non-continuous blocks contain `cache_control` - the first continuous block will be used. (sent to `/cachedContent` in the [Gemini format](https://ai.google.dev/api/caching#cache_create-SHELL))
-
- The raw request to Gemini's `/generateContent` endpoint looks like this:
```bash
@@ -1253,7 +1278,6 @@ curl -X POST "https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5
```
-
### Example Usage
@@ -1293,6 +1317,48 @@ for _ in range(2):
print(resp.usage) # 👈 2nd usage block will be less, since cached tokens used
```
+
+
+
+```python
+from litellm import completion
+
+# Cache for 2 hours (7200 seconds)
+resp = completion(
+ model="gemini/gemini-1.5-pro",
+ messages=[
+ {
+ "role": "system",
+ "content": [
+ {
+ "type": "text",
+ "text": "Here is the full text of a complex legal agreement" * 4000,
+ "cache_control": {
+ "type": "ephemeral",
+ "ttl": "7200s" # 👈 Cache for 2 hours
+ },
+ }
+ ],
+ },
+ {
+ "role": "user",
+ "content": [
+ {
+ "type": "text",
+ "text": "What are the key terms and conditions in this agreement?",
+ "cache_control": {
+ "type": "ephemeral",
+ "ttl": "3600s" # 👈 This TTL will be ignored (first one is used)
+ },
+ }
+ ],
+ }
+ ]
+)
+
+print(resp.usage)
+```
+
@@ -1350,6 +1416,44 @@ curl --location 'http://0.0.0.0:4000/chat/completions' \
}'
```
+
+
+```bash
+curl --location 'http://0.0.0.0:4000/chat/completions' \
+ --header 'Content-Type: application/json' \
+ --data '{
+ "model": "gemini-1.5-pro",
+ "messages": [
+ {
+ "role": "system",
+ "content": [
+ {
+ "type": "text",
+ "text": "Here is the full text of a complex legal agreement" * 4000,
+ "cache_control": {
+ "type": "ephemeral",
+ "ttl": "7200s"
+ }
+ }
+ ]
+ },
+ {
+ "role": "user",
+ "content": [
+ {
+ "type": "text",
+ "text": "What are the key terms and conditions in this agreement?",
+ "cache_control": {
+ "type": "ephemeral",
+ "ttl": "3600s"
+ }
+ }
+ ]
+ }
+ ]
+}'
+```
+
```python
@@ -1382,6 +1486,40 @@ response = await client.chat.completions.create(
```
+
+
+
+```python
+import openai
+client = openai.AsyncOpenAI(
+ api_key="anything", # litellm proxy api key
+ base_url="http://0.0.0.0:4000" # litellm proxy base url
+)
+
+response = await client.chat.completions.create(
+ model="gemini-1.5-pro",
+ messages=[
+ {
+ "role": "system",
+ "content": [
+ {
+ "type": "text",
+ "text": "Here is the full text of a complex legal agreement" * 4000,
+ "cache_control": {
+ "type": "ephemeral",
+ "ttl": "7200s" # Cache for 2 hours
+ }
+ }
+ ],
+ },
+ {
+ "role": "user",
+ "content": "what are the key terms and conditions in this agreement?",
+ },
+ ]
+)
+```
+
diff --git a/litellm/llms/vertex_ai/context_caching/transformation.py b/litellm/llms/vertex_ai/context_caching/transformation.py
index 83c15029b2..f3ca699546 100644
--- a/litellm/llms/vertex_ai/context_caching/transformation.py
+++ b/litellm/llms/vertex_ai/context_caching/transformation.py
@@ -4,7 +4,8 @@ Transformation logic for context caching.
Why separate file? Make it easy to see how transformation works
"""
-from typing import List, Tuple
+import re
+from typing import List, Optional, Tuple
from litellm.types.llms.openai import AllMessageValues
from litellm.types.llms.vertex_ai import CachedContentRequestBody
@@ -47,6 +48,72 @@ def get_first_continuous_block_idx(
return len(filtered_messages) - 1
+def extract_ttl_from_cached_messages(messages: List[AllMessageValues]) -> Optional[str]:
+ """
+ Extract TTL from cached messages. Returns the first valid TTL found.
+
+ Args:
+ messages: List of messages to extract TTL from
+
+ Returns:
+ Optional[str]: TTL string in format "3600s" or None if not found/invalid
+ """
+ for message in messages:
+ if not is_cached_message(message):
+ continue
+
+ content = message.get("content")
+ if not content or isinstance(content, str):
+ continue
+
+ for content_item in content:
+ # Type check to ensure content_item is a dictionary before calling .get()
+ if not isinstance(content_item, dict):
+ continue
+
+ cache_control = content_item.get("cache_control")
+ if not cache_control or not isinstance(cache_control, dict):
+ continue
+
+ if cache_control.get("type") != "ephemeral":
+ continue
+
+ ttl = cache_control.get("ttl")
+ if ttl and _is_valid_ttl_format(ttl):
+ return str(ttl)
+
+ return None
+
+
+def _is_valid_ttl_format(ttl: str) -> bool:
+ """
+ Validate TTL format. Should be a string ending with 's' for seconds.
+ Examples: "3600s", "7200s", "1.5s"
+
+ Args:
+ ttl: TTL string to validate
+
+ Returns:
+ bool: True if valid format, False otherwise
+ """
+ if not isinstance(ttl, str):
+ return False
+
+ # TTL should end with 's' and contain a valid number before it
+ pattern = r'^([0-9]*\.?[0-9]+)s$'
+ match = re.match(pattern, ttl)
+
+ if not match:
+ return False
+
+ try:
+ # Ensure the numeric part is valid and positive
+ numeric_part = float(match.group(1))
+ return numeric_part > 0
+ except ValueError:
+ return False
+
+
def separate_cached_messages(
messages: List[AllMessageValues],
) -> Tuple[List[AllMessageValues], List[AllMessageValues]]:
@@ -90,6 +157,9 @@ def separate_cached_messages(
def transform_openai_messages_to_gemini_context_caching(
model: str, messages: List[AllMessageValues], cache_key: str
) -> CachedContentRequestBody:
+ # Extract TTL from cached messages BEFORE system message transformation
+ ttl = extract_ttl_from_cached_messages(messages)
+
supports_system_message = get_supports_system_message(
model=model, custom_llm_provider="gemini"
)
@@ -99,11 +169,17 @@ def transform_openai_messages_to_gemini_context_caching(
)
transformed_messages = _gemini_convert_messages_with_history(messages=new_messages)
+
data = CachedContentRequestBody(
contents=transformed_messages,
model="models/{}".format(model),
displayName=cache_key,
)
+
+ # Add TTL if present and valid
+ if ttl:
+ data["ttl"] = ttl
+
if transformed_system_messages is not None:
data["system_instruction"] = transformed_system_messages
diff --git a/tests/llm_translation/test_gemini.py b/tests/llm_translation/test_gemini.py
index c41d914e53..ae279bbde4 100644
--- a/tests/llm_translation/test_gemini.py
+++ b/tests/llm_translation/test_gemini.py
@@ -12,6 +12,7 @@ sys.path.insert(
from base_llm_unit_tests import BaseLLMChatTest
from litellm.llms.vertex_ai.context_caching.transformation import (
separate_cached_messages,
+ transform_openai_messages_to_gemini_context_caching,
)
import litellm
from litellm import completion
@@ -67,6 +68,153 @@ class TestGoogleAIStudioGemini(BaseLLMChatTest):
print(f"response={response}")
+def test_gemini_context_caching_with_ttl():
+ """Test Gemini context caching with TTL support"""
+
+ # Test case 1: Basic TTL functionality
+ messages_with_ttl = [
+ {
+ "role": "system",
+ "content": [
+ {
+ "type": "text",
+ "text": "Here is the full text of a complex legal agreement" * 400,
+ "cache_control": {"type": "ephemeral", "ttl": "3600s"},
+ }
+ ],
+ },
+ {
+ "role": "user",
+ "content": [
+ {
+ "type": "text",
+ "text": "What are the key terms and conditions in this agreement?",
+ "cache_control": {"type": "ephemeral", "ttl": "7200s"},
+ }
+ ],
+ }
+ ]
+
+ # Test the transformation function directly
+ result = transform_openai_messages_to_gemini_context_caching(
+ model="gemini-1.5-pro",
+ messages=messages_with_ttl,
+ cache_key="test-ttl-cache-key"
+ )
+
+ # Verify TTL is properly included in the result
+ assert "ttl" in result
+ assert result["ttl"] == "3600s" # Should use the first valid TTL found
+ assert result["model"] == "models/gemini-1.5-pro"
+ assert result["displayName"] == "test-ttl-cache-key"
+
+ # Test case 2: Invalid TTL should be ignored
+ messages_invalid_ttl = [
+ {
+ "role": "user",
+ "content": [
+ {
+ "type": "text",
+ "text": "Cached content with invalid TTL",
+ "cache_control": {"type": "ephemeral", "ttl": "invalid_ttl"},
+ }
+ ],
+ }
+ ]
+
+ result_invalid = transform_openai_messages_to_gemini_context_caching(
+ model="gemini-1.5-pro",
+ messages=messages_invalid_ttl,
+ cache_key="test-invalid-ttl"
+ )
+
+ # Verify invalid TTL is not included
+ assert "ttl" not in result_invalid
+ assert result_invalid["model"] == "models/gemini-1.5-pro"
+ assert result_invalid["displayName"] == "test-invalid-ttl"
+
+ # Test case 3: Messages without TTL should work normally
+ messages_no_ttl = [
+ {
+ "role": "user",
+ "content": [
+ {
+ "type": "text",
+ "text": "Cached content without TTL",
+ "cache_control": {"type": "ephemeral"},
+ }
+ ],
+ }
+ ]
+
+ result_no_ttl = transform_openai_messages_to_gemini_context_caching(
+ model="gemini-1.5-pro",
+ messages=messages_no_ttl,
+ cache_key="test-no-ttl"
+ )
+
+ # Verify no TTL field is present when not specified
+ assert "ttl" not in result_no_ttl
+ assert result_no_ttl["model"] == "models/gemini-1.5-pro"
+ assert result_no_ttl["displayName"] == "test-no-ttl"
+
+ # Test case 4: Mixed messages with some having TTL
+ messages_mixed = [
+ {
+ "role": "system",
+ "content": [
+ {
+ "type": "text",
+ "text": "System message with TTL",
+ "cache_control": {"type": "ephemeral", "ttl": "1800s"},
+ }
+ ],
+ },
+ {
+ "role": "user",
+ "content": [
+ {
+ "type": "text",
+ "text": "User message without TTL",
+ "cache_control": {"type": "ephemeral"},
+ }
+ ],
+ },
+ {
+ "role": "assistant",
+ "content": "Assistant response without cache control"
+ },
+ {
+ "role": "user",
+ "content": [
+ {
+ "type": "text",
+ "text": "Another user message",
+ "cache_control": {"type": "ephemeral", "ttl": "900s"},
+ }
+ ],
+ }
+ ]
+
+ # Test separation of cached messages
+ cached_messages, non_cached_messages = separate_cached_messages(messages_mixed)
+ assert len(cached_messages) > 0
+ assert len(non_cached_messages) > 0
+
+ # Test transformation with mixed messages
+ result_mixed = transform_openai_messages_to_gemini_context_caching(
+ model="gemini-1.5-pro",
+ messages=messages_mixed,
+ cache_key="test-mixed-ttl"
+ )
+
+ # Should pick up the first valid TTL
+ assert "ttl" in result_mixed
+ assert result_mixed["ttl"] == "1800s"
+ assert result_mixed["model"] == "models/gemini-1.5-pro"
+ assert result_mixed["displayName"] == "test-mixed-ttl"
+
+
def test_gemini_context_caching_separate_messages():
messages = [
# System Message
diff --git a/tests/test_litellm/llms/vertex_ai/context_caching/test_context_caching_ttl.py b/tests/test_litellm/llms/vertex_ai/context_caching/test_context_caching_ttl.py
new file mode 100644
index 0000000000..cce6055ab3
--- /dev/null
+++ b/tests/test_litellm/llms/vertex_ai/context_caching/test_context_caching_ttl.py
@@ -0,0 +1,360 @@
+import pytest
+from litellm.llms.vertex_ai.context_caching.transformation import (
+ extract_ttl_from_cached_messages,
+ _is_valid_ttl_format,
+ transform_openai_messages_to_gemini_context_caching,
+)
+
+
+class TestTTLValidation:
+ """Test TTL format validation"""
+
+ def test_valid_ttl_formats(self):
+ """Test various valid TTL formats"""
+ valid_ttls = [
+ "3600s",
+ "1s",
+ "7200s",
+ "1.5s",
+ "0.1s",
+ "86400s",
+ "123.456s"
+ ]
+
+ for ttl in valid_ttls:
+ assert _is_valid_ttl_format(ttl), f"TTL {ttl} should be valid"
+
+ def test_invalid_ttl_formats(self):
+ """Test various invalid TTL formats"""
+ invalid_ttls = [
+ "3600", # missing 's'
+ "s", # missing number
+ "-1s", # negative number
+ "0s", # zero
+ "3600m", # wrong unit
+ "abc.s", # invalid number
+ "", # empty string
+ "3600.s", # invalid decimal
+ "3600 s", # space
+ "3600ss", # extra 's'
+ None, # None
+ 123, # not a string
+ ]
+
+ for ttl in invalid_ttls:
+ assert not _is_valid_ttl_format(ttl), f"TTL {ttl} should be invalid"
+
+
+class TestTTLExtraction:
+ """Test TTL extraction from cached messages"""
+
+ def test_extract_ttl_from_single_message(self):
+ """Test extracting TTL from a single cached message"""
+ messages = [
+ {
+ "role": "user",
+ "content": [
+ {
+ "type": "text",
+ "text": "This is cached content",
+ "cache_control": {"type": "ephemeral", "ttl": "3600s"}
+ }
+ ]
+ }
+ ]
+
+ ttl = extract_ttl_from_cached_messages(messages)
+ assert ttl == "3600s"
+
+ def test_extract_ttl_from_multiple_messages(self):
+ """Test extracting TTL from multiple cached messages (should return first valid one)"""
+ messages = [
+ {
+ "role": "system",
+ "content": [
+ {
+ "type": "text",
+ "text": "System message",
+ "cache_control": {"type": "ephemeral", "ttl": "7200s"}
+ }
+ ]
+ },
+ {
+ "role": "user",
+ "content": [
+ {
+ "type": "text",
+ "text": "User message",
+ "cache_control": {"type": "ephemeral", "ttl": "3600s"}
+ }
+ ]
+ }
+ ]
+
+ ttl = extract_ttl_from_cached_messages(messages)
+ assert ttl == "7200s" # Should return the first valid TTL found
+
+ def test_extract_ttl_no_cache_control(self):
+ """Test extracting TTL from messages without cache_control"""
+ messages = [
+ {
+ "role": "user",
+ "content": [
+ {
+ "type": "text",
+ "text": "Regular message without cache control"
+ }
+ ]
+ }
+ ]
+
+ ttl = extract_ttl_from_cached_messages(messages)
+ assert ttl is None
+
+ def test_extract_ttl_invalid_format(self):
+ """Test extracting TTL with invalid format"""
+ messages = [
+ {
+ "role": "user",
+ "content": [
+ {
+ "type": "text",
+ "text": "Cached content with invalid TTL",
+ "cache_control": {"type": "ephemeral", "ttl": "invalid"}
+ }
+ ]
+ }
+ ]
+
+ ttl = extract_ttl_from_cached_messages(messages)
+ assert ttl is None
+
+ def test_extract_ttl_missing_ttl_field(self):
+ """Test extracting TTL when ttl field is missing"""
+ messages = [
+ {
+ "role": "user",
+ "content": [
+ {
+ "type": "text",
+ "text": "Cached content without TTL field",
+ "cache_control": {"type": "ephemeral"}
+ }
+ ]
+ }
+ ]
+
+ ttl = extract_ttl_from_cached_messages(messages)
+ assert ttl is None
+
+ def test_extract_ttl_mixed_valid_invalid(self):
+ """Test extracting TTL when some messages have valid TTL and others don't"""
+ messages = [
+ {
+ "role": "system",
+ "content": [
+ {
+ "type": "text",
+ "text": "System message with invalid TTL",
+ "cache_control": {"type": "ephemeral", "ttl": "invalid"}
+ }
+ ]
+ },
+ {
+ "role": "user",
+ "content": [
+ {
+ "type": "text",
+ "text": "User message with valid TTL",
+ "cache_control": {"type": "ephemeral", "ttl": "3600s"}
+ }
+ ]
+ }
+ ]
+
+ ttl = extract_ttl_from_cached_messages(messages)
+ assert ttl == "3600s" # Should return the first valid TTL found
+
+ def test_extract_ttl_string_content(self):
+ """Test extracting TTL when message content is a string (not a list)"""
+ messages = [
+ {
+ "role": "user",
+ "content": "String content"
+ }
+ ]
+
+ ttl = extract_ttl_from_cached_messages(messages)
+ assert ttl is None
+
+
+class TestTransformationWithTTL:
+ """Test the complete transformation with TTL support"""
+
+ def test_transform_with_valid_ttl(self):
+ """Test transformation includes TTL when provided"""
+ messages = [
+ {
+ "role": "user",
+ "content": [
+ {
+ "type": "text",
+ "text": "Cached content",
+ "cache_control": {"type": "ephemeral", "ttl": "3600s"}
+ }
+ ]
+ }
+ ]
+
+ result = transform_openai_messages_to_gemini_context_caching(
+ model="gemini-1.5-pro",
+ messages=messages,
+ cache_key="test-cache-key"
+ )
+
+ assert "ttl" in result
+ assert result["ttl"] == "3600s"
+ assert result["model"] == "models/gemini-1.5-pro"
+ assert result["displayName"] == "test-cache-key"
+
+ def test_transform_without_ttl(self):
+ """Test transformation without TTL"""
+ messages = [
+ {
+ "role": "user",
+ "content": [
+ {
+ "type": "text",
+ "text": "Cached content",
+ "cache_control": {"type": "ephemeral"}
+ }
+ ]
+ }
+ ]
+
+ result = transform_openai_messages_to_gemini_context_caching(
+ model="gemini-1.5-pro",
+ messages=messages,
+ cache_key="test-cache-key"
+ )
+
+ assert "ttl" not in result
+ assert result["model"] == "models/gemini-1.5-pro"
+ assert result["displayName"] == "test-cache-key"
+
+ def test_transform_with_invalid_ttl(self):
+ """Test transformation with invalid TTL (should be ignored)"""
+ messages = [
+ {
+ "role": "user",
+ "content": [
+ {
+ "type": "text",
+ "text": "Cached content",
+ "cache_control": {"type": "ephemeral", "ttl": "invalid"}
+ }
+ ]
+ }
+ ]
+
+ result = transform_openai_messages_to_gemini_context_caching(
+ model="gemini-1.5-pro",
+ messages=messages,
+ cache_key="test-cache-key"
+ )
+
+ assert "ttl" not in result
+ assert result["model"] == "models/gemini-1.5-pro"
+ assert result["displayName"] == "test-cache-key"
+
+ def test_transform_with_system_message_and_ttl(self):
+ """Test transformation with system message and TTL"""
+ messages = [
+ {
+ "role": "system",
+ "content": [
+ {
+ "type": "text",
+ "text": "System instruction",
+ "cache_control": {"type": "ephemeral", "ttl": "7200s"}
+ }
+ ]
+ },
+ {
+ "role": "user",
+ "content": [
+ {
+ "type": "text",
+ "text": "User message"
+ }
+ ]
+ }
+ ]
+
+ result = transform_openai_messages_to_gemini_context_caching(
+ model="gemini-1.5-pro",
+ messages=messages,
+ cache_key="test-cache-key"
+ )
+
+ assert "ttl" in result
+ assert result["ttl"] == "7200s"
+ assert "system_instruction" in result
+ assert result["model"] == "models/gemini-1.5-pro"
+ assert result["displayName"] == "test-cache-key"
+
+
+class TestEdgeCases:
+ """Test edge cases and error conditions"""
+
+ def test_ttl_extraction_empty_messages(self):
+ """Test TTL extraction with empty message list"""
+ messages = []
+ ttl = extract_ttl_from_cached_messages(messages)
+ assert ttl is None
+
+ def test_ttl_extraction_none_content(self):
+ """Test TTL extraction when content is None"""
+ messages = [
+ {
+ "role": "user",
+ "content": None
+ }
+ ]
+ ttl = extract_ttl_from_cached_messages(messages)
+ assert ttl is None
+
+ def test_ttl_extraction_empty_content_list(self):
+ """Test TTL extraction when content list is empty"""
+ messages = [
+ {
+ "role": "user",
+ "content": []
+ }
+ ]
+ ttl = extract_ttl_from_cached_messages(messages)
+ assert ttl is None
+
+ def test_ttl_validation_type_conversion(self):
+ """Test TTL validation handles type conversion properly"""
+ # Test that numeric TTL gets converted to string
+ messages = [
+ {
+ "role": "user",
+ "content": [
+ {
+ "type": "text",
+ "text": "Cached content",
+ "cache_control": {"type": "ephemeral", "ttl": "3600s"}
+ }
+ ]
+ }
+ ]
+
+ ttl = extract_ttl_from_cached_messages(messages)
+ assert isinstance(ttl, str)
+ assert ttl == "3600s"
+
+
+if __name__ == "__main__":
+ pytest.main([__file__, "-v"])
\ No newline at end of file