From ccc085faeea4dcbffa824e962e83c789c3d4e231 Mon Sep 17 00:00:00 2001 From: Krish Dholakia Date: Mon, 2 Jun 2025 23:14:38 -0700 Subject: [PATCH] Merge in - Gemini streaming - thinking content parsing - return in `reasoning_content` (#11298) * fix(base_routing_strategy.py): compress increments to redis - reduces write ops * fix(base_routing_strategy.py): make get and reset in memory keys atomic * fix(base_routing_strategy.py): don't reset keys - causes discrepancy on subsequent requests to instance * fix(parallel_request_limiter.py): retrieve values of previous slots from cache more accurate rate limiting with sliding window * fix: fix test * fix: fix linting error * fix(gemini/): fix streaming handler for function calling Closes https://github.com/BerriAI/litellm/pull/11294 * fix: fix linting error * test: update test * fix(vertex_and_google_ai_studio_gemini.py): return none on skipped chunk * fix(streaming_handler.py): skip none chunks on async streaming --- .../prompt_templates/common_utils.py | 18 +++++ .../litellm_core_utils/streaming_handler.py | 8 ++- .../vertex_and_google_ai_studio_gemini.py | 67 +++++++++++-------- .../vertex_passthrough_logging_handler.py | 1 + tests/local_testing/test_streaming.py | 2 +- ...test_vertex_and_google_ai_studio_gemini.py | 55 +++++++++++++-- 6 files changed, 114 insertions(+), 37 deletions(-) diff --git a/litellm/litellm_core_utils/prompt_templates/common_utils.py b/litellm/litellm_core_utils/prompt_templates/common_utils.py index c94b276a7b..acac97bd3e 100644 --- a/litellm/litellm_core_utils/prompt_templates/common_utils.py +++ b/litellm/litellm_core_utils/prompt_templates/common_utils.py @@ -46,6 +46,9 @@ DEFAULT_ASSISTANT_CONTINUE_MESSAGE = ChatCompletionAssistantMessage( content="Please continue.", role="assistant" ) +if TYPE_CHECKING: + from litellm.litellm_core_utils.litellm_logging import Logging as LoggingClass + def handle_any_messages_to_chat_completion_str_messages_conversion( messages: Any, @@ -618,6 +621,20 @@ 
def get_file_ids_from_messages(messages: List[AllMessageValues]) -> List[str]: return file_ids + +def check_is_function_call(logging_obj: "LoggingClass") -> bool: + from litellm.litellm_core_utils.prompt_templates.common_utils import ( + is_function_call, + ) + + if hasattr(logging_obj, "optional_params") and isinstance( + logging_obj.optional_params, dict + ): + if is_function_call(logging_obj.optional_params): + return True + + return False + def filter_value_from_dict(dictionary: dict, key: str, depth: int = 0) -> Any: """ Filters a value from a dictionary @@ -670,3 +687,4 @@ def migrate_file_to_image_url( image_url_object["image_url"]["format"] = format return image_url_object + diff --git a/litellm/litellm_core_utils/streaming_handler.py b/litellm/litellm_core_utils/streaming_handler.py index a3f6bffa55..ea9d867256 100644 --- a/litellm/litellm_core_utils/streaming_handler.py +++ b/litellm/litellm_core_utils/streaming_handler.py @@ -1359,6 +1359,7 @@ class CustomStreamWrapper: print_verbose(f"self.sent_first_chunk: {self.sent_first_chunk}") ## CHECK FOR TOOL USE + if "tool_calls" in completion_obj and len(completion_obj["tool_calls"]) > 0: if self.is_function_call is True: # user passed in 'functions' param completion_obj["function_call"] = completion_obj["tool_calls"][0][ @@ -1633,7 +1634,8 @@ class CustomStreamWrapper: if is_async_iterable(self.completion_stream): async for chunk in self.completion_stream: if chunk == "None" or chunk is None: - raise Exception + continue # skip None chunks + elif ( self.custom_llm_provider == "gemini" and hasattr(chunk, "parts") @@ -1642,7 +1644,9 @@ class CustomStreamWrapper: continue # chunk_creator() does logging/stream chunk building. We need to let it know its being called in_async_func, so we don't double add chunks. 
# __anext__ also calls async_success_handler, which does logging - print_verbose(f"PROCESSED ASYNC CHUNK PRE CHUNK CREATOR: {chunk}") + verbose_logger.debug( + f"PROCESSED ASYNC CHUNK PRE CHUNK CREATOR: {chunk}" + ) processed_chunk: Optional[ModelResponseStream] = self.chunk_creator( chunk=chunk diff --git a/litellm/llms/vertex_ai/gemini/vertex_and_google_ai_studio_gemini.py b/litellm/llms/vertex_ai/gemini/vertex_and_google_ai_studio_gemini.py index c568dfb205..ba89bb073e 100644 --- a/litellm/llms/vertex_ai/gemini/vertex_and_google_ai_studio_gemini.py +++ b/litellm/llms/vertex_ai/gemini/vertex_and_google_ai_studio_gemini.py @@ -1267,7 +1267,9 @@ async def make_call( ) completion_stream = ModelResponseIterator( - streaming_response=response.aiter_lines(), sync_stream=False + streaming_response=response.aiter_lines(), + sync_stream=False, + logging_obj=logging_obj, ) # LOGGING logging_obj.post_call( @@ -1305,7 +1307,9 @@ def make_sync_call( ) completion_stream = ModelResponseIterator( - streaming_response=response.iter_lines(), sync_stream=True + streaming_response=response.iter_lines(), + sync_stream=True, + logging_obj=logging_obj, ) # LOGGING @@ -1726,11 +1730,19 @@ class VertexLLM(VertexBase): class ModelResponseIterator: - def __init__(self, streaming_response, sync_stream: bool): + def __init__( + self, streaming_response, sync_stream: bool, logging_obj: LoggingClass + ): + from litellm.litellm_core_utils.prompt_templates.common_utils import ( + check_is_function_call, + ) + self.streaming_response = streaming_response self.chunk_type: Literal["valid_json", "accumulated_json"] = "valid_json" self.accumulated_json = "" self.sent_first_chunk = False + self.logging_obj = logging_obj + self.is_function_call = check_is_function_call(logging_obj) def chunk_parser(self, chunk: dict) -> GenericStreamingChunk: try: @@ -1794,11 +1806,23 @@ class ModelResponseIterator: }, ) - returned_chunk = GenericStreamingChunk( - text=text, - tool_use=tool_use, - is_finished=False, - 
finish_reason=finish_reason, + args: Dict[str, Any] = { + "content": text or None, + "reasoning_content": reasoning_content, + } + if self.is_function_call and tool_use is not None: + args["function_call"] = tool_use["function"] + elif tool_use is not None: + args["tool_calls"] = [tool_use] + + returned_chunk = ModelResponseStream( + choices=[ + StreamingChoices( + index=0, + delta=Delta(**args), + finish_reason=finish_reason, + ) + ], usage=usage, index=0, ) @@ -1811,7 +1835,7 @@ class ModelResponseIterator: self.response_iterator = self.streaming_response return self - def handle_valid_json_chunk(self, chunk: str) -> GenericStreamingChunk: + def handle_valid_json_chunk(self, chunk: str) -> Optional[ModelResponseStream]: chunk = chunk.strip() try: json_chunk = json.loads(chunk) @@ -1829,7 +1853,9 @@ class ModelResponseIterator: return self.chunk_parser(chunk=json_chunk) - def handle_accumulated_json_chunk(self, chunk: str) -> GenericStreamingChunk: + def handle_accumulated_json_chunk( + self, chunk: str + ) -> Optional[ModelResponseStream]: chunk = litellm.CustomStreamWrapper._strip_sse_data_from_chunk(chunk) or "" message = chunk.replace("\n\n", "") @@ -1843,16 +1869,9 @@ class ModelResponseIterator: return self.chunk_parser(chunk=_data) except json.JSONDecodeError: # If it's not valid JSON yet, continue to the next event - return GenericStreamingChunk( - text="", - is_finished=False, - finish_reason="", - usage=None, - index=0, - tool_use=None, - ) + return None - def _common_chunk_parsing_logic(self, chunk: str) -> GenericStreamingChunk: + def _common_chunk_parsing_logic(self, chunk: str) -> Optional[ModelResponseStream]: try: chunk = litellm.CustomStreamWrapper._strip_sse_data_from_chunk(chunk) or "" if len(chunk) > 0: @@ -1865,15 +1884,7 @@ class ModelResponseIterator: return self.handle_valid_json_chunk(chunk=chunk) elif self.chunk_type == "accumulated_json": return self.handle_accumulated_json_chunk(chunk=chunk) - - return GenericStreamingChunk( - text="", 
- is_finished=False, - finish_reason="", - usage=None, - index=0, - tool_use=None, - ) + return None except Exception: raise diff --git a/litellm/proxy/pass_through_endpoints/llm_provider_handlers/vertex_passthrough_logging_handler.py b/litellm/proxy/pass_through_endpoints/llm_provider_handlers/vertex_passthrough_logging_handler.py index a20f39e65c..afac84326f 100644 --- a/litellm/proxy/pass_through_endpoints/llm_provider_handlers/vertex_passthrough_logging_handler.py +++ b/litellm/proxy/pass_through_endpoints/llm_provider_handlers/vertex_passthrough_logging_handler.py @@ -190,6 +190,7 @@ class VertexPassthroughLoggingHandler: vertex_iterator = VertexModelResponseIterator( streaming_response=None, sync_stream=False, + logging_obj=litellm_logging_obj, ) litellm_custom_stream_wrapper = litellm.CustomStreamWrapper( completion_stream=vertex_iterator, diff --git a/tests/local_testing/test_streaming.py b/tests/local_testing/test_streaming.py index ffeda23edd..8112357e0f 100644 --- a/tests/local_testing/test_streaming.py +++ b/tests/local_testing/test_streaming.py @@ -712,7 +712,7 @@ async def test_acompletion_claude_2_stream(): @pytest.mark.flaky(retries=3, delay=1) async def test_completion_gemini_stream(sync_mode): try: - litellm.set_verbose = True + litellm._turn_on_debug() print("Streaming gemini response") function1 = [ { diff --git a/tests/test_litellm/llms/vertex_ai/gemini/test_vertex_and_google_ai_studio_gemini.py b/tests/test_litellm/llms/vertex_ai/gemini/test_vertex_and_google_ai_studio_gemini.py index 18ffd7ca60..f62408b1a1 100644 --- a/tests/test_litellm/llms/vertex_ai/gemini/test_vertex_and_google_ai_studio_gemini.py +++ b/tests/test_litellm/llms/vertex_ai/gemini/test_vertex_and_google_ai_studio_gemini.py @@ -322,6 +322,8 @@ def test_streaming_chunk_includes_reasoning_tokens(): ModelResponseIterator, ) + litellm_logging = MagicMock() + # Simulate a streaming chunk as would be received from Gemini chunk = { "candidates": [{"content": {"parts": [{"text": 
"Hello"}]}}], @@ -332,14 +334,54 @@ def test_streaming_chunk_includes_reasoning_tokens(): "thoughtsTokenCount": 3, }, } - iterator = ModelResponseIterator(streaming_response=[], sync_stream=True) + iterator = ModelResponseIterator( + streaming_response=[], sync_stream=True, logging_obj=litellm_logging + ) streaming_chunk = iterator.chunk_parser(chunk) - assert streaming_chunk["usage"] is not None - assert streaming_chunk["usage"]["prompt_tokens"] == 5 - assert streaming_chunk["usage"]["completion_tokens"] == 7 - assert streaming_chunk["usage"]["total_tokens"] == 12 + assert streaming_chunk.usage is not None + assert streaming_chunk.usage.prompt_tokens == 5 + assert streaming_chunk.usage.completion_tokens == 7 + assert streaming_chunk.usage.total_tokens == 12 + assert streaming_chunk.usage.completion_tokens_details.reasoning_tokens == 3 + + +def test_streaming_chunk_includes_reasoning_content(): + """ + Ensure that when Gemini returns a chunk with `thought=True`, the parser maps it to `reasoning_content`. + """ + from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import ( + ModelResponseIterator, + ) + + litellm_logging = MagicMock() + + # Simulate a streaming chunk from Gemini which contains reasoning (thought) content + chunk = { + "candidates": [ + { + "content": { + "parts": [ + { + "text": "I'm thinking through the problem...", + "thought": True, + } + ] + } + } + ], + "usageMetadata": {}, + } + + iterator = ModelResponseIterator( + streaming_response=[], sync_stream=True, logging_obj=litellm_logging + ) + streaming_chunk = iterator.chunk_parser(chunk) + + # The text content should be empty and reasoning_content should be populated + assert streaming_chunk.choices[0].delta.content is None assert ( - streaming_chunk["usage"]["completion_tokens_details"]["reasoning_tokens"] == 3 + streaming_chunk.choices[0].delta.reasoning_content + == "I'm thinking through the problem..." 
) @@ -393,6 +435,7 @@ def test_vertex_ai_map_thinking_param_with_budget_tokens_0(): "thinkingBudget": 100, } + def test_vertex_ai_map_tools(): v = VertexGeminiConfig() tools = v._map_function(value=[{"code_execution": {}}])