Merge in - Gemini streaming - thinking content parsing - return in reasoning_content (#11298)

* fix(base_routing_strategy.py): compress increments to redis - reduces write ops

* fix(base_routing_strategy.py): make get and reset in memory keys atomic

* fix(base_routing_strategy.py): don't reset keys - causes discrepancy on subsequent requests to instance

* fix(parallel_request_limiter.py): retrieve values of previous slots from cache

more accurate rate limiting with sliding window

* fix: fix test

* fix: fix linting error

* fix(gemini/): fix streaming handler for function calling

Closes https://github.com/BerriAI/litellm/pull/11294

* fix: fix linting error

* test: update test

* fix(vertex_and_google_ai_studio_gemini.py): return none on skipped chunk

* fix(streaming_handler.py): skip none chunks on async streaming
This commit is contained in:
Krish Dholakia
2025-06-02 23:14:38 -07:00
committed by GitHub
parent a366f9247a
commit ccc085faee
6 changed files with 114 additions and 37 deletions
@@ -46,6 +46,9 @@ DEFAULT_ASSISTANT_CONTINUE_MESSAGE = ChatCompletionAssistantMessage(
content="Please continue.", role="assistant"
)
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as LoggingClass
def handle_any_messages_to_chat_completion_str_messages_conversion(
messages: Any,
@@ -618,6 +621,20 @@ def get_file_ids_from_messages(messages: List[AllMessageValues]) -> List[str]:
return file_ids
def check_is_function_call(logging_obj: "LoggingClass") -> bool:
    """
    Tell whether `is_function_call` accepts the logging object's optional params.

    Args:
        logging_obj: Object whose `optional_params` attribute is inspected.

    Returns:
        True only when `logging_obj.optional_params` exists, is a dict, and
        `is_function_call` returns a truthy value for it; False otherwise.
    """
    from litellm.litellm_core_utils.prompt_templates.common_utils import (
        is_function_call,
    )

    # Guard clause: a missing attribute or a non-dict value both mean "no".
    optional_params = getattr(logging_obj, "optional_params", None)
    if not isinstance(optional_params, dict):
        return False

    return True if is_function_call(optional_params) else False
def filter_value_from_dict(dictionary: dict, key: str, depth: int = 0) -> Any:
"""
Filters a value from a dictionary
@@ -670,3 +687,4 @@ def migrate_file_to_image_url(
image_url_object["image_url"]["format"] = format
return image_url_object
@@ -1359,6 +1359,7 @@ class CustomStreamWrapper:
print_verbose(f"self.sent_first_chunk: {self.sent_first_chunk}")
## CHECK FOR TOOL USE
if "tool_calls" in completion_obj and len(completion_obj["tool_calls"]) > 0:
if self.is_function_call is True: # user passed in 'functions' param
completion_obj["function_call"] = completion_obj["tool_calls"][0][
@@ -1633,7 +1634,8 @@ class CustomStreamWrapper:
if is_async_iterable(self.completion_stream):
async for chunk in self.completion_stream:
if chunk == "None" or chunk is None:
raise Exception
continue # skip None chunks
elif (
self.custom_llm_provider == "gemini"
and hasattr(chunk, "parts")
@@ -1642,7 +1644,9 @@ class CustomStreamWrapper:
continue
# chunk_creator() does logging/stream chunk building. We need to let it know its being called in_async_func, so we don't double add chunks.
# __anext__ also calls async_success_handler, which does logging
print_verbose(f"PROCESSED ASYNC CHUNK PRE CHUNK CREATOR: {chunk}")
verbose_logger.debug(
f"PROCESSED ASYNC CHUNK PRE CHUNK CREATOR: {chunk}"
)
processed_chunk: Optional[ModelResponseStream] = self.chunk_creator(
chunk=chunk
@@ -1267,7 +1267,9 @@ async def make_call(
)
completion_stream = ModelResponseIterator(
streaming_response=response.aiter_lines(), sync_stream=False
streaming_response=response.aiter_lines(),
sync_stream=False,
logging_obj=logging_obj,
)
# LOGGING
logging_obj.post_call(
@@ -1305,7 +1307,9 @@ def make_sync_call(
)
completion_stream = ModelResponseIterator(
streaming_response=response.iter_lines(), sync_stream=True
streaming_response=response.iter_lines(),
sync_stream=True,
logging_obj=logging_obj,
)
# LOGGING
@@ -1726,11 +1730,19 @@ class VertexLLM(VertexBase):
class ModelResponseIterator:
def __init__(self, streaming_response, sync_stream: bool):
def __init__(
self, streaming_response, sync_stream: bool, logging_obj: LoggingClass
):
from litellm.litellm_core_utils.prompt_templates.common_utils import (
check_is_function_call,
)
self.streaming_response = streaming_response
self.chunk_type: Literal["valid_json", "accumulated_json"] = "valid_json"
self.accumulated_json = ""
self.sent_first_chunk = False
self.logging_obj = logging_obj
self.is_function_call = check_is_function_call(logging_obj)
def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
try:
@@ -1794,11 +1806,23 @@ class ModelResponseIterator:
},
)
returned_chunk = GenericStreamingChunk(
text=text,
tool_use=tool_use,
is_finished=False,
finish_reason=finish_reason,
args: Dict[str, Any] = {
"content": text or None,
"reasoning_content": reasoning_content,
}
if self.is_function_call and tool_use is not None:
args["function_call"] = tool_use["function"]
elif tool_use is not None:
args["tool_calls"] = [tool_use]
returned_chunk = ModelResponseStream(
choices=[
StreamingChoices(
index=0,
delta=Delta(**args),
finish_reason=finish_reason,
)
],
usage=usage,
index=0,
)
@@ -1811,7 +1835,7 @@ class ModelResponseIterator:
self.response_iterator = self.streaming_response
return self
def handle_valid_json_chunk(self, chunk: str) -> GenericStreamingChunk:
def handle_valid_json_chunk(self, chunk: str) -> Optional[ModelResponseStream]:
chunk = chunk.strip()
try:
json_chunk = json.loads(chunk)
@@ -1829,7 +1853,9 @@ class ModelResponseIterator:
return self.chunk_parser(chunk=json_chunk)
def handle_accumulated_json_chunk(self, chunk: str) -> GenericStreamingChunk:
def handle_accumulated_json_chunk(
self, chunk: str
) -> Optional[ModelResponseStream]:
chunk = litellm.CustomStreamWrapper._strip_sse_data_from_chunk(chunk) or ""
message = chunk.replace("\n\n", "")
@@ -1843,16 +1869,9 @@ class ModelResponseIterator:
return self.chunk_parser(chunk=_data)
except json.JSONDecodeError:
# If it's not valid JSON yet, continue to the next event
return GenericStreamingChunk(
text="",
is_finished=False,
finish_reason="",
usage=None,
index=0,
tool_use=None,
)
return None
def _common_chunk_parsing_logic(self, chunk: str) -> GenericStreamingChunk:
def _common_chunk_parsing_logic(self, chunk: str) -> Optional[ModelResponseStream]:
try:
chunk = litellm.CustomStreamWrapper._strip_sse_data_from_chunk(chunk) or ""
if len(chunk) > 0:
@@ -1865,15 +1884,7 @@ class ModelResponseIterator:
return self.handle_valid_json_chunk(chunk=chunk)
elif self.chunk_type == "accumulated_json":
return self.handle_accumulated_json_chunk(chunk=chunk)
return GenericStreamingChunk(
text="",
is_finished=False,
finish_reason="",
usage=None,
index=0,
tool_use=None,
)
return None
except Exception:
raise
@@ -190,6 +190,7 @@ class VertexPassthroughLoggingHandler:
vertex_iterator = VertexModelResponseIterator(
streaming_response=None,
sync_stream=False,
logging_obj=litellm_logging_obj,
)
litellm_custom_stream_wrapper = litellm.CustomStreamWrapper(
completion_stream=vertex_iterator,
+1 -1
View File
@@ -712,7 +712,7 @@ async def test_acompletion_claude_2_stream():
@pytest.mark.flaky(retries=3, delay=1)
async def test_completion_gemini_stream(sync_mode):
try:
litellm.set_verbose = True
litellm._turn_on_debug()
print("Streaming gemini response")
function1 = [
{
@@ -322,6 +322,8 @@ def test_streaming_chunk_includes_reasoning_tokens():
ModelResponseIterator,
)
litellm_logging = MagicMock()
# Simulate a streaming chunk as would be received from Gemini
chunk = {
"candidates": [{"content": {"parts": [{"text": "Hello"}]}}],
@@ -332,14 +334,54 @@ def test_streaming_chunk_includes_reasoning_tokens():
"thoughtsTokenCount": 3,
},
}
iterator = ModelResponseIterator(streaming_response=[], sync_stream=True)
iterator = ModelResponseIterator(
streaming_response=[], sync_stream=True, logging_obj=litellm_logging
)
streaming_chunk = iterator.chunk_parser(chunk)
assert streaming_chunk["usage"] is not None
assert streaming_chunk["usage"]["prompt_tokens"] == 5
assert streaming_chunk["usage"]["completion_tokens"] == 7
assert streaming_chunk["usage"]["total_tokens"] == 12
assert streaming_chunk.usage is not None
assert streaming_chunk.usage.prompt_tokens == 5
assert streaming_chunk.usage.completion_tokens == 7
assert streaming_chunk.usage.total_tokens == 12
assert streaming_chunk.usage.completion_tokens_details.reasoning_tokens == 3
def test_streaming_chunk_includes_reasoning_content():
"""
Ensure that when Gemini returns a chunk with `thought=True`, the parser maps it to `reasoning_content`.
"""
from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import (
ModelResponseIterator,
)
litellm_logging = MagicMock()
# Simulate a streaming chunk from Gemini which contains reasoning (thought) content
chunk = {
"candidates": [
{
"content": {
"parts": [
{
"text": "I'm thinking through the problem...",
"thought": True,
}
]
}
}
],
"usageMetadata": {},
}
iterator = ModelResponseIterator(
streaming_response=[], sync_stream=True, logging_obj=litellm_logging
)
streaming_chunk = iterator.chunk_parser(chunk)
# The text content should be empty and reasoning_content should be populated
assert streaming_chunk.choices[0].delta.content is None
assert (
streaming_chunk["usage"]["completion_tokens_details"]["reasoning_tokens"] == 3
streaming_chunk.choices[0].delta.reasoning_content
== "I'm thinking through the problem..."
)
@@ -393,6 +435,7 @@ def test_vertex_ai_map_thinking_param_with_budget_tokens_0():
"thinkingBudget": 100,
}
def test_vertex_ai_map_tools():
v = VertexGeminiConfig()
tools = v._map_function(value=[{"code_execution": {}}])