From 094ce8f772cf6c8cd540a71052008d0cf78bc44f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcelo=20D=C3=ADaz?= <40875838+marcelodiaz558@users.noreply.github.com> Date: Tue, 15 Jul 2025 01:30:54 -0400 Subject: [PATCH] feat(gemini): Add custom TTL support for context caching (#9810) (#12541) - Add ttl parameter to cache_control for Gemini models - Support Google's TTL format (e.g., '3600s', '7200s') - Implement robust TTL extraction and validation - Extract TTL before system message transformation to handle all cases - Add comprehensive test suite with 17 test cases in tests/test_litellm/ - Update documentation with TTL usage examples - Maintain backward compatibility with existing cache_control usage Fixes #9810 --- docs/my-website/docs/providers/gemini.md | 146 ++++++- .../context_caching/transformation.py | 78 +++- tests/llm_translation/test_gemini.py | 148 +++++++ .../test_context_caching_ttl.py | 360 ++++++++++++++++++ 4 files changed, 727 insertions(+), 5 deletions(-) create mode 100644 tests/test_litellm/llms/vertex_ai/context_caching/test_context_caching_ttl.py diff --git a/docs/my-website/docs/providers/gemini.md b/docs/my-website/docs/providers/gemini.md index 0d388a4151..9376144cc8 100644 --- a/docs/my-website/docs/providers/gemini.md +++ b/docs/my-website/docs/providers/gemini.md @@ -1219,12 +1219,38 @@ Use Google AI Studio context caching is supported by in your message content block. +### Custom TTL Support + +You can now specify a custom Time-To-Live (TTL) for your cached content using the `ttl` parameter: + +```bash +{ + { + "role": "system", + "content": ..., + "cache_control": { + "type": "ephemeral", + "ttl": "3600s" # 👈 Cache for 1 hour + } + }, + ... 
+} +``` + +**TTL Format Requirements:** +- Must be a string ending with 's' for seconds +- Must contain a positive number (can be decimal) +- Examples: `"3600s"` (1 hour), `"7200s"` (2 hours), `"1800s"` (30 minutes), `"1.5s"` (1.5 seconds) + +**TTL Behavior:** +- If multiple cached messages have different TTLs, the first valid TTL encountered will be used +- Invalid TTL formats are ignored and the cache will use Google's default expiration time +- If no TTL is specified, Google's default cache expiration (approximately 1 hour) applies + ### Architecture Diagram - - **Notes:** - [Relevant code](https://github.com/BerriAI/litellm/blob/main/litellm/llms/vertex_ai/context_caching/vertex_ai_context_caching.py#L255) @@ -1233,7 +1259,6 @@ in your message content block. - If multiple non-continuous blocks contain `cache_control` - the first continuous block will be used. (sent to `/cachedContent` in the [Gemini format](https://ai.google.dev/api/caching#cache_create-SHELL)) - - The raw request to Gemini's `/generateContent` endpoint looks like this: ```bash @@ -1253,7 +1278,6 @@ curl -X POST "https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5 ``` - ### Example Usage @@ -1293,6 +1317,48 @@ for _ in range(2): print(resp.usage) # 👈 2nd usage block will be less, since cached tokens used ``` + + + +```python +from litellm import completion + +# Cache for 2 hours (7200 seconds) +resp = completion( + model="gemini/gemini-1.5-pro", + messages=[ + { + "role": "system", + "content": [ + { + "type": "text", + "text": "Here is the full text of a complex legal agreement" * 4000, + "cache_control": { + "type": "ephemeral", + "ttl": "7200s" # 👈 Cache for 2 hours + }, + } + ], + }, + { + "role": "user", + "content": [ + { + "type": "text", + "text": "What are the key terms and conditions in this agreement?", + "cache_control": { + "type": "ephemeral", + "ttl": "3600s" # 👈 This TTL will be ignored (first one is used) + }, + } + ], + } + ] +) + +print(resp.usage) +``` + @@ 
-1350,6 +1416,44 @@ curl --location 'http://0.0.0.0:4000/chat/completions' \ }' ``` + + +```bash +curl --location 'http://0.0.0.0:4000/chat/completions' \ + --header 'Content-Type: application/json' \ + --data '{ + "model": "gemini-1.5-pro", + "messages": [ + { + "role": "system", + "content": [ + { + "type": "text", + "text": "Here is the full text of a complex legal agreement" * 4000, + "cache_control": { + "type": "ephemeral", + "ttl": "7200s" + } + } + ] + }, + { + "role": "user", + "content": [ + { + "type": "text", + "text": "What are the key terms and conditions in this agreement?", + "cache_control": { + "type": "ephemeral", + "ttl": "3600s" + } + } + ] + } + ] +}' +``` + ```python @@ -1382,6 +1486,40 @@ response = await client.chat.completions.create( ``` + + + +```python +import openai +client = openai.AsyncOpenAI( + api_key="anything", # litellm proxy api key + base_url="http://0.0.0.0:4000" # litellm proxy base url +) + +response = await client.chat.completions.create( + model="gemini-1.5-pro", + messages=[ + { + "role": "system", + "content": [ + { + "type": "text", + "text": "Here is the full text of a complex legal agreement" * 4000, + "cache_control": { + "type": "ephemeral", + "ttl": "7200s" # Cache for 2 hours + } + } + ], + }, + { + "role": "user", + "content": "what are the key terms and conditions in this agreement?", + }, + ] +) +``` + diff --git a/litellm/llms/vertex_ai/context_caching/transformation.py b/litellm/llms/vertex_ai/context_caching/transformation.py index 83c15029b2..f3ca699546 100644 --- a/litellm/llms/vertex_ai/context_caching/transformation.py +++ b/litellm/llms/vertex_ai/context_caching/transformation.py @@ -4,7 +4,8 @@ Transformation logic for context caching. Why separate file? 
Make it easy to see how transformation works """ -from typing import List, Tuple +import re +from typing import List, Optional, Tuple from litellm.types.llms.openai import AllMessageValues from litellm.types.llms.vertex_ai import CachedContentRequestBody @@ -47,6 +48,72 @@ def get_first_continuous_block_idx( return len(filtered_messages) - 1 +def extract_ttl_from_cached_messages(messages: List[AllMessageValues]) -> Optional[str]: + """ + Extract TTL from cached messages. Returns the first valid TTL found. + + Args: + messages: List of messages to extract TTL from + + Returns: + Optional[str]: TTL string in format "3600s" or None if not found/invalid + """ + for message in messages: + if not is_cached_message(message): + continue + + content = message.get("content") + if not content or isinstance(content, str): + continue + + for content_item in content: + # Type check to ensure content_item is a dictionary before calling .get() + if not isinstance(content_item, dict): + continue + + cache_control = content_item.get("cache_control") + if not cache_control or not isinstance(cache_control, dict): + continue + + if cache_control.get("type") != "ephemeral": + continue + + ttl = cache_control.get("ttl") + if ttl and _is_valid_ttl_format(ttl): + return str(ttl) + + return None + + +def _is_valid_ttl_format(ttl: str) -> bool: + """ + Validate TTL format. Should be a string ending with 's' for seconds. 
+ Examples: "3600s", "7200s", "1.5s" + + Args: + ttl: TTL string to validate + + Returns: + bool: True if valid format, False otherwise + """ + if not isinstance(ttl, str): + return False + + # TTL should end with 's' and contain a valid number before it + pattern = r'^([0-9]*\.?[0-9]+)s$' + match = re.match(pattern, ttl) + + if not match: + return False + + try: + # Ensure the numeric part is valid and positive + numeric_part = float(match.group(1)) + return numeric_part > 0 + except ValueError: + return False + + def separate_cached_messages( messages: List[AllMessageValues], ) -> Tuple[List[AllMessageValues], List[AllMessageValues]]: @@ -90,6 +157,9 @@ def separate_cached_messages( def transform_openai_messages_to_gemini_context_caching( model: str, messages: List[AllMessageValues], cache_key: str ) -> CachedContentRequestBody: + # Extract TTL from cached messages BEFORE system message transformation + ttl = extract_ttl_from_cached_messages(messages) + supports_system_message = get_supports_system_message( model=model, custom_llm_provider="gemini" ) @@ -99,11 +169,17 @@ def transform_openai_messages_to_gemini_context_caching( ) transformed_messages = _gemini_convert_messages_with_history(messages=new_messages) + data = CachedContentRequestBody( contents=transformed_messages, model="models/{}".format(model), displayName=cache_key, ) + + # Add TTL if present and valid + if ttl: + data["ttl"] = ttl + if transformed_system_messages is not None: data["system_instruction"] = transformed_system_messages diff --git a/tests/llm_translation/test_gemini.py b/tests/llm_translation/test_gemini.py index c41d914e53..ae279bbde4 100644 --- a/tests/llm_translation/test_gemini.py +++ b/tests/llm_translation/test_gemini.py @@ -12,6 +12,7 @@ sys.path.insert( from base_llm_unit_tests import BaseLLMChatTest from litellm.llms.vertex_ai.context_caching.transformation import ( separate_cached_messages, + transform_openai_messages_to_gemini_context_caching, ) import litellm from litellm 
def test_gemini_context_caching_with_ttl():
    """Check how TTL values on cache_control reach the Gemini caching request body."""
    model_name = "gemini-1.5-pro"

    def cached_text(role, text, ttl=None):
        # One-part text message marked ephemeral, with an optional TTL.
        control = {"type": "ephemeral"}
        if ttl is not None:
            control["ttl"] = ttl
        return {
            "role": role,
            "content": [{"type": "text", "text": text, "cache_control": control}],
        }

    def build(messages, cache_key):
        return transform_openai_messages_to_gemini_context_caching(
            model=model_name, messages=messages, cache_key=cache_key
        )

    def assert_envelope(body, cache_key):
        # Model and display name are independent of the TTL handling.
        assert body["model"] == "models/gemini-1.5-pro"
        assert body["displayName"] == cache_key

    # Case 1: two cached messages with different TTLs - the first one wins.
    with_ttl = build(
        [
            cached_text(
                "system",
                "Here is the full text of a complex legal agreement" * 400,
                ttl="3600s",
            ),
            cached_text(
                "user",
                "What are the key terms and conditions in this agreement?",
                ttl="7200s",
            ),
        ],
        "test-ttl-cache-key",
    )
    assert "ttl" in with_ttl
    assert with_ttl["ttl"] == "3600s"
    assert_envelope(with_ttl, "test-ttl-cache-key")

    # Case 2: a malformed TTL is dropped from the request.
    invalid_ttl = build(
        [cached_text("user", "Cached content with invalid TTL", ttl="invalid_ttl")],
        "test-invalid-ttl",
    )
    assert "ttl" not in invalid_ttl
    assert_envelope(invalid_ttl, "test-invalid-ttl")

    # Case 3: no TTL supplied - the field stays absent.
    no_ttl = build(
        [cached_text("user", "Cached content without TTL")],
        "test-no-ttl",
    )
    assert "ttl" not in no_ttl
    assert_envelope(no_ttl, "test-no-ttl")

    # Case 4: cached and uncached messages interleaved.
    mixed_messages = [
        cached_text("system", "System message with TTL", ttl="1800s"),
        cached_text("user", "User message without TTL"),
        {
            "role": "assistant",
            "content": "Assistant response without cache control",
        },
        cached_text("user", "Another user message", ttl="900s"),
    ]

    cached_part, uncached_part = separate_cached_messages(mixed_messages)
    assert len(cached_part) > 0
    assert len(uncached_part) > 0

    mixed = build(mixed_messages, "test-mixed-ttl")
    assert "ttl" in mixed
    assert mixed["ttl"] == "1800s"
    assert_envelope(mixed, "test-mixed-ttl")
"""Unit tests for TTL extraction and validation in Gemini context caching."""

import pytest

from litellm.llms.vertex_ai.context_caching.transformation import (
    _is_valid_ttl_format,
    extract_ttl_from_cached_messages,
    transform_openai_messages_to_gemini_context_caching,
)


def _ephemeral(ttl=None):
    """Return an ephemeral cache_control dict, adding `ttl` only when given."""
    control = {"type": "ephemeral"}
    if ttl is not None:
        control["ttl"] = ttl
    return control


def _text_message(role, text, cache_control=None):
    """Return a message with a single text part and optional cache_control."""
    part = {"type": "text", "text": text}
    if cache_control is not None:
        part["cache_control"] = cache_control
    return {"role": role, "content": [part]}


def _transform(messages):
    """Run the transformation with the model and cache key used by every test."""
    return transform_openai_messages_to_gemini_context_caching(
        model="gemini-1.5-pro",
        messages=messages,
        cache_key="test-cache-key",
    )


class TestTTLValidation:
    """Test TTL format validation"""

    @pytest.mark.parametrize(
        "ttl",
        ["3600s", "1s", "7200s", "1.5s", "0.1s", "86400s", "123.456s"],
    )
    def test_valid_ttl_formats(self, ttl):
        """Test various valid TTL formats"""
        assert _is_valid_ttl_format(ttl), f"TTL {ttl} should be valid"

    @pytest.mark.parametrize(
        "ttl",
        [
            "3600",  # missing 's'
            "s",  # missing number
            "-1s",  # negative number
            "0s",  # zero
            "3600m",  # wrong unit
            "abc.s",  # invalid number
            "",  # empty string
            "3600.s",  # invalid decimal
            "3600 s",  # space
            "3600ss",  # extra 's'
            None,  # None
            123,  # not a string
        ],
    )
    def test_invalid_ttl_formats(self, ttl):
        """Test various invalid TTL formats"""
        assert not _is_valid_ttl_format(ttl), f"TTL {ttl} should be invalid"


class TestTTLExtraction:
    """Test TTL extraction from cached messages"""

    def test_extract_ttl_from_single_message(self):
        """Test extracting TTL from a single cached message"""
        messages = [
            _text_message("user", "This is cached content", _ephemeral("3600s")),
        ]

        assert extract_ttl_from_cached_messages(messages) == "3600s"

    def test_extract_ttl_from_multiple_messages(self):
        """Test extracting TTL from multiple cached messages (should return first valid one)"""
        messages = [
            _text_message("system", "System message", _ephemeral("7200s")),
            _text_message("user", "User message", _ephemeral("3600s")),
        ]

        # The system message comes first, so its TTL wins.
        assert extract_ttl_from_cached_messages(messages) == "7200s"

    def test_extract_ttl_no_cache_control(self):
        """Test extracting TTL from messages without cache_control"""
        messages = [
            _text_message("user", "Regular message without cache control"),
        ]

        assert extract_ttl_from_cached_messages(messages) is None

    def test_extract_ttl_invalid_format(self):
        """Test extracting TTL with invalid format"""
        messages = [
            _text_message(
                "user", "Cached content with invalid TTL", _ephemeral("invalid")
            ),
        ]

        assert extract_ttl_from_cached_messages(messages) is None

    def test_extract_ttl_missing_ttl_field(self):
        """Test extracting TTL when ttl field is missing"""
        messages = [
            _text_message("user", "Cached content without TTL field", _ephemeral()),
        ]

        assert extract_ttl_from_cached_messages(messages) is None

    def test_extract_ttl_mixed_valid_invalid(self):
        """Test extracting TTL when some messages have valid TTL and others don't"""
        messages = [
            _text_message(
                "system", "System message with invalid TTL", _ephemeral("invalid")
            ),
            _text_message(
                "user", "User message with valid TTL", _ephemeral("3600s")
            ),
        ]

        # The invalid TTL is skipped; the next valid one is returned.
        assert extract_ttl_from_cached_messages(messages) == "3600s"

    def test_extract_ttl_string_content(self):
        """Test extracting TTL when message content is a string (not a list)"""
        messages = [{"role": "user", "content": "String content"}]

        assert extract_ttl_from_cached_messages(messages) is None


class TestTransformationWithTTL:
    """Test the complete transformation with TTL support"""

    def test_transform_with_valid_ttl(self):
        """Test transformation includes TTL when provided"""
        result = _transform(
            [_text_message("user", "Cached content", _ephemeral("3600s"))]
        )

        assert "ttl" in result
        assert result["ttl"] == "3600s"
        assert result["model"] == "models/gemini-1.5-pro"
        assert result["displayName"] == "test-cache-key"

    def test_transform_without_ttl(self):
        """Test transformation without TTL"""
        result = _transform([_text_message("user", "Cached content", _ephemeral())])

        assert "ttl" not in result
        assert result["model"] == "models/gemini-1.5-pro"
        assert result["displayName"] == "test-cache-key"

    def test_transform_with_invalid_ttl(self):
        """Test transformation with invalid TTL (should be ignored)"""
        result = _transform(
            [_text_message("user", "Cached content", _ephemeral("invalid"))]
        )

        assert "ttl" not in result
        assert result["model"] == "models/gemini-1.5-pro"
        assert result["displayName"] == "test-cache-key"

    def test_transform_with_system_message_and_ttl(self):
        """Test transformation with system message and TTL"""
        result = _transform(
            [
                _text_message("system", "System instruction", _ephemeral("7200s")),
                _text_message("user", "User message"),
            ]
        )

        assert "ttl" in result
        assert result["ttl"] == "7200s"
        assert "system_instruction" in result
        assert result["model"] == "models/gemini-1.5-pro"
        assert result["displayName"] == "test-cache-key"


class TestEdgeCases:
    """Test edge cases and error conditions"""

    def test_ttl_extraction_empty_messages(self):
        """Test TTL extraction with empty message list"""
        assert extract_ttl_from_cached_messages([]) is None

    def test_ttl_extraction_none_content(self):
        """Test TTL extraction when content is None"""
        messages = [{"role": "user", "content": None}]

        assert extract_ttl_from_cached_messages(messages) is None

    def test_ttl_extraction_empty_content_list(self):
        """Test TTL extraction when content list is empty"""
        messages = [{"role": "user", "content": []}]

        assert extract_ttl_from_cached_messages(messages) is None

    def test_ttl_validation_type_conversion(self):
        """Test that only string TTLs are returned; numeric TTLs are not converted"""
        string_ttl = [
            _text_message("user", "Cached content", _ephemeral("3600s")),
        ]

        ttl = extract_ttl_from_cached_messages(string_ttl)
        assert isinstance(ttl, str)
        assert ttl == "3600s"

        # A bare number fails validation, so it is ignored rather than being
        # coerced to "3600s".
        numeric_ttl = [
            _text_message("user", "Cached content", _ephemeral(3600)),
        ]
        assert extract_ttl_from_cached_messages(numeric_ttl) is None


if __name__ == "__main__":
    pytest.main([__file__, "-v"])