diff --git a/litellm/llms/gemini/count_tokens/handler.py b/litellm/llms/gemini/count_tokens/handler.py index a0782a838b..4d6c7fd886 100644 --- a/litellm/llms/gemini/count_tokens/handler.py +++ b/litellm/llms/gemini/count_tokens/handler.py @@ -11,7 +11,39 @@ if TYPE_CHECKING: else: GenerateContentContentListUnionDict = Any + class GoogleAIStudioTokenCounter: + def _clean_contents_for_gemini_api(self, contents: Any) -> Any: + """ + Clean up contents to remove unsupported fields for the Gemini API. + + The Google Gemini API doesn't recognize the 'id' field in function responses, + so we need to remove it to prevent 400 Bad Request errors. + + Args: + contents: The contents to clean up + + Returns: + Cleaned contents with unsupported fields removed + """ + import copy + + from google.genai.types import FunctionResponse + + cleaned_contents = copy.deepcopy(contents) + + for content in cleaned_contents: + parts = content["parts"] + for part in parts: + if "functionResponse" in part: + function_response_data = part["functionResponse"] + function_response_part = FunctionResponse(**function_response_data) + function_response_part.id = None + part["functionResponse"] = function_response_part.model_dump( + exclude_none=True + ) + + return cleaned_contents def _construct_url(self, model: str, api_base: Optional[str] = None) -> str: """ @@ -20,7 +52,6 @@ class GoogleAIStudioTokenCounter: base_url = api_base or "https://generativelanguage.googleapis.com" return f"{base_url}/v1beta/models/{model}:countTokens" - async def validate_environment( self, api_base: Optional[str] = None, @@ -33,7 +64,8 @@ class GoogleAIStudioTokenCounter: Returns a Tuple of headers and url for the Google Gen AI Studio countTokens endpoint. 
""" from litellm.llms.gemini.google_genai.transformation import GoogleGenAIConfig - headers = GoogleGenAIConfig().validate_environment( + + headers = GoogleGenAIConfig().validate_environment( api_key=api_key, headers=headers, model=model, @@ -54,7 +86,7 @@ class GoogleAIStudioTokenCounter: ) -> Dict[str, Any]: """ Count tokens using Google Gen AI Studio countTokens endpoint. - + Args: contents: The content to count tokens for (Google Gen AI format) Example: [{"parts": [{"text": "Hello world"}]}] @@ -63,7 +95,7 @@ class GoogleAIStudioTokenCounter: api_base: Optional API base URL (defaults to Google Gen AI Studio) timeout: Optional timeout for the request **kwargs: Additional parameters - + Returns: Dict containing token count information from Google Gen AI Studio API. Example response: @@ -77,14 +109,13 @@ class GoogleAIStudioTokenCounter: } ] } - + Raises: ValueError: If API key is missing litellm.APIError: If the API call fails litellm.APIConnectionError: If the connection fails Exception: For any other unexpected errors """ - # Set up API base URL # Prepare headers headers, url = await self.validate_environment( @@ -94,44 +125,39 @@ class GoogleAIStudioTokenCounter: model=model, litellm_params=kwargs, ) - - # Prepare request body - request_body = { - "contents": contents - } - + + # Prepare request body - clean up contents to remove unsupported fields + cleaned_contents = self._clean_contents_for_gemini_api(contents) + request_body = {"contents": cleaned_contents} + async_httpx_client = get_async_httpx_client( llm_provider=LlmProviders.GEMINI, ) try: response = await async_httpx_client.post( - url=url, - headers=headers, - json=request_body + url=url, headers=headers, json=request_body ) - + # Check for HTTP errors response.raise_for_status() - + # Parse response result = response.json() return result - + except httpx.HTTPStatusError as e: error_msg = f"Google Gen AI Studio API error: {e.response.status_code} - {e.response.text}" raise litellm.APIError( 
def test_clean_contents_for_gemini_api_removes_id_field(self):
    """The 'id' key is stripped from function responses; the input is not mutated."""
    import copy

    from litellm.llms.gemini.count_tokens.handler import GoogleAIStudioTokenCounter

    token_counter = GoogleAIStudioTokenCounter()

    # Contents with a function response carrying an 'id' field (camelCase key).
    contents_with_id = [
        {
            "parts": [{"text": "Hello world"}],
            "role": "user",
        },
        {
            "parts": [
                {
                    "functionResponse": {
                        "id": "read_many_files-1757526647518-730a691aac11c",  # This should be removed
                        "name": "read_many_files",
                        "response": {
                            "output": "No files matching the criteria were found or all were skipped."
                        },
                    }
                }
            ],
            "role": "user",
        },
    ]
    # Snapshot taken before the call so mutation of the input can be detected.
    snapshot = copy.deepcopy(contents_with_id)

    cleaned_contents = token_counter._clean_contents_for_gemini_api(contents_with_id)

    # The 'id' field is gone; the remaining fields survive with their values.
    function_response = cleaned_contents[1]["parts"][0]["functionResponse"]
    assert "id" not in function_response
    assert "name" in function_response
    assert "response" in function_response
    assert function_response["name"] == "read_many_files"
    assert (
        function_response["response"]["output"]
        == "No files matching the criteria were found or all were skipped."
    )

    # Parts that are not function responses pass through untouched.
    assert cleaned_contents[0] == snapshot[0]

    # The caller's payload must not be modified by the cleanup.
    assert contents_with_id == snapshot
    assert "id" in contents_with_id[1]["parts"][0]["functionResponse"]


def test_clean_contents_for_gemini_api_preserves_other_fields(self):
    """Contents without function responses come back equal, as a separate copy."""
    from litellm.llms.gemini.count_tokens.handler import GoogleAIStudioTokenCounter

    token_counter = GoogleAIStudioTokenCounter()

    # Contents without any function responses.
    contents_without_function_response = [
        {
            "parts": [{"text": "This is a regular message"}],
            "role": "user",
        },
        {
            "parts": [{"text": "This is a model response"}],
            "role": "model",
        },
    ]

    cleaned_contents = token_counter._clean_contents_for_gemini_api(
        contents_without_function_response
    )

    # Same value ...
    assert cleaned_contents == contents_without_function_response
    # ... but a distinct object, so later edits to the request body cannot
    # leak back into the caller's data.
    assert cleaned_contents is not contents_without_function_response