mirror of
https://github.com/tiennm99/litellm.git
synced 2026-06-17 22:48:35 +00:00
Merge in - Gemini streaming - thinking content parsing - return in reasoning_content (#11298)
* fix(base_routing_strategy.py): compress increments to redis - reduces write ops * fix(base_routing_strategy.py): make get and reset in memory keys atomic * fix(base_routing_strategy.py): don't reset keys - causes discrepancy on subsequent requests to instance * fix(parallel_request_limiter.py): retrieve values of previous slots from cache more accurate rate limiting with sliding window * fix: fix test * fix: fix linting error * fix(gemini/): fix streaming handler for function calling Closes https://github.com/BerriAI/litellm/pull/11294 * fix: fix linting error * test: update test * fix(vertex_and_google_ai_studio_gemini.py): return none on skipped chunk * fix(streaming_handler.py): skip none chunks on async streaming
This commit is contained in:
@@ -46,6 +46,9 @@ DEFAULT_ASSISTANT_CONTINUE_MESSAGE = ChatCompletionAssistantMessage(
|
||||
content="Please continue.", role="assistant"
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LoggingClass
|
||||
|
||||
|
||||
def handle_any_messages_to_chat_completion_str_messages_conversion(
|
||||
messages: Any,
|
||||
@@ -618,6 +621,20 @@ def get_file_ids_from_messages(messages: List[AllMessageValues]) -> List[str]:
|
||||
return file_ids
|
||||
|
||||
|
||||
|
||||
def check_is_function_call(logging_obj: "LoggingClass") -> bool:
    """
    Tell whether the optional params carried by `logging_obj` describe a
    function call.

    Args:
        logging_obj: Logging object that may expose an `optional_params`
            attribute.

    Returns:
        True only when `logging_obj.optional_params` exists, is a dict, and
        `is_function_call` evaluates truthy for it. False in every other case.
    """
    # Imported at call time, as the first statement, exactly like the original.
    from litellm.litellm_core_utils.prompt_templates.common_utils import (
        is_function_call,
    )

    # A missing attribute and a non-dict value are both rejected by one check.
    optional_params = getattr(logging_obj, "optional_params", None)
    if not isinstance(optional_params, dict):
        return False

    return bool(is_function_call(optional_params))
|
||||
|
||||
def filter_value_from_dict(dictionary: dict, key: str, depth: int = 0) -> Any:
|
||||
"""
|
||||
Filters a value from a dictionary
|
||||
@@ -670,3 +687,4 @@ def migrate_file_to_image_url(
|
||||
image_url_object["image_url"]["format"] = format
|
||||
return image_url_object
|
||||
|
||||
|
||||
|
||||
@@ -1359,6 +1359,7 @@ class CustomStreamWrapper:
|
||||
print_verbose(f"self.sent_first_chunk: {self.sent_first_chunk}")
|
||||
|
||||
## CHECK FOR TOOL USE
|
||||
|
||||
if "tool_calls" in completion_obj and len(completion_obj["tool_calls"]) > 0:
|
||||
if self.is_function_call is True: # user passed in 'functions' param
|
||||
completion_obj["function_call"] = completion_obj["tool_calls"][0][
|
||||
@@ -1633,7 +1634,8 @@ class CustomStreamWrapper:
|
||||
if is_async_iterable(self.completion_stream):
|
||||
async for chunk in self.completion_stream:
|
||||
if chunk == "None" or chunk is None:
|
||||
raise Exception
|
||||
continue # skip None chunks
|
||||
|
||||
elif (
|
||||
self.custom_llm_provider == "gemini"
|
||||
and hasattr(chunk, "parts")
|
||||
@@ -1642,7 +1644,9 @@ class CustomStreamWrapper:
|
||||
continue
|
||||
# chunk_creator() does logging/stream chunk building. We need to let it know its being called in_async_func, so we don't double add chunks.
|
||||
# __anext__ also calls async_success_handler, which does logging
|
||||
print_verbose(f"PROCESSED ASYNC CHUNK PRE CHUNK CREATOR: {chunk}")
|
||||
verbose_logger.debug(
|
||||
f"PROCESSED ASYNC CHUNK PRE CHUNK CREATOR: {chunk}"
|
||||
)
|
||||
|
||||
processed_chunk: Optional[ModelResponseStream] = self.chunk_creator(
|
||||
chunk=chunk
|
||||
|
||||
@@ -1267,7 +1267,9 @@ async def make_call(
|
||||
)
|
||||
|
||||
completion_stream = ModelResponseIterator(
|
||||
streaming_response=response.aiter_lines(), sync_stream=False
|
||||
streaming_response=response.aiter_lines(),
|
||||
sync_stream=False,
|
||||
logging_obj=logging_obj,
|
||||
)
|
||||
# LOGGING
|
||||
logging_obj.post_call(
|
||||
@@ -1305,7 +1307,9 @@ def make_sync_call(
|
||||
)
|
||||
|
||||
completion_stream = ModelResponseIterator(
|
||||
streaming_response=response.iter_lines(), sync_stream=True
|
||||
streaming_response=response.iter_lines(),
|
||||
sync_stream=True,
|
||||
logging_obj=logging_obj,
|
||||
)
|
||||
|
||||
# LOGGING
|
||||
@@ -1726,11 +1730,19 @@ class VertexLLM(VertexBase):
|
||||
|
||||
|
||||
class ModelResponseIterator:
|
||||
def __init__(self, streaming_response, sync_stream: bool):
|
||||
def __init__(
|
||||
self, streaming_response, sync_stream: bool, logging_obj: LoggingClass
|
||||
):
|
||||
from litellm.litellm_core_utils.prompt_templates.common_utils import (
|
||||
check_is_function_call,
|
||||
)
|
||||
|
||||
self.streaming_response = streaming_response
|
||||
self.chunk_type: Literal["valid_json", "accumulated_json"] = "valid_json"
|
||||
self.accumulated_json = ""
|
||||
self.sent_first_chunk = False
|
||||
self.logging_obj = logging_obj
|
||||
self.is_function_call = check_is_function_call(logging_obj)
|
||||
|
||||
def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
|
||||
try:
|
||||
@@ -1794,11 +1806,23 @@ class ModelResponseIterator:
|
||||
},
|
||||
)
|
||||
|
||||
returned_chunk = GenericStreamingChunk(
|
||||
text=text,
|
||||
tool_use=tool_use,
|
||||
is_finished=False,
|
||||
finish_reason=finish_reason,
|
||||
args: Dict[str, Any] = {
|
||||
"content": text or None,
|
||||
"reasoning_content": reasoning_content,
|
||||
}
|
||||
if self.is_function_call and tool_use is not None:
|
||||
args["function_call"] = tool_use["function"]
|
||||
elif tool_use is not None:
|
||||
args["tool_calls"] = [tool_use]
|
||||
|
||||
returned_chunk = ModelResponseStream(
|
||||
choices=[
|
||||
StreamingChoices(
|
||||
index=0,
|
||||
delta=Delta(**args),
|
||||
finish_reason=finish_reason,
|
||||
)
|
||||
],
|
||||
usage=usage,
|
||||
index=0,
|
||||
)
|
||||
@@ -1811,7 +1835,7 @@ class ModelResponseIterator:
|
||||
self.response_iterator = self.streaming_response
|
||||
return self
|
||||
|
||||
def handle_valid_json_chunk(self, chunk: str) -> GenericStreamingChunk:
|
||||
def handle_valid_json_chunk(self, chunk: str) -> Optional[ModelResponseStream]:
|
||||
chunk = chunk.strip()
|
||||
try:
|
||||
json_chunk = json.loads(chunk)
|
||||
@@ -1829,7 +1853,9 @@ class ModelResponseIterator:
|
||||
|
||||
return self.chunk_parser(chunk=json_chunk)
|
||||
|
||||
def handle_accumulated_json_chunk(self, chunk: str) -> GenericStreamingChunk:
|
||||
def handle_accumulated_json_chunk(
|
||||
self, chunk: str
|
||||
) -> Optional[ModelResponseStream]:
|
||||
chunk = litellm.CustomStreamWrapper._strip_sse_data_from_chunk(chunk) or ""
|
||||
message = chunk.replace("\n\n", "")
|
||||
|
||||
@@ -1843,16 +1869,9 @@ class ModelResponseIterator:
|
||||
return self.chunk_parser(chunk=_data)
|
||||
except json.JSONDecodeError:
|
||||
# If it's not valid JSON yet, continue to the next event
|
||||
return GenericStreamingChunk(
|
||||
text="",
|
||||
is_finished=False,
|
||||
finish_reason="",
|
||||
usage=None,
|
||||
index=0,
|
||||
tool_use=None,
|
||||
)
|
||||
return None
|
||||
|
||||
def _common_chunk_parsing_logic(self, chunk: str) -> GenericStreamingChunk:
|
||||
def _common_chunk_parsing_logic(self, chunk: str) -> Optional[ModelResponseStream]:
|
||||
try:
|
||||
chunk = litellm.CustomStreamWrapper._strip_sse_data_from_chunk(chunk) or ""
|
||||
if len(chunk) > 0:
|
||||
@@ -1865,15 +1884,7 @@ class ModelResponseIterator:
|
||||
return self.handle_valid_json_chunk(chunk=chunk)
|
||||
elif self.chunk_type == "accumulated_json":
|
||||
return self.handle_accumulated_json_chunk(chunk=chunk)
|
||||
|
||||
return GenericStreamingChunk(
|
||||
text="",
|
||||
is_finished=False,
|
||||
finish_reason="",
|
||||
usage=None,
|
||||
index=0,
|
||||
tool_use=None,
|
||||
)
|
||||
return None
|
||||
except Exception:
|
||||
raise
|
||||
|
||||
|
||||
+1
@@ -190,6 +190,7 @@ class VertexPassthroughLoggingHandler:
|
||||
vertex_iterator = VertexModelResponseIterator(
|
||||
streaming_response=None,
|
||||
sync_stream=False,
|
||||
logging_obj=litellm_logging_obj,
|
||||
)
|
||||
litellm_custom_stream_wrapper = litellm.CustomStreamWrapper(
|
||||
completion_stream=vertex_iterator,
|
||||
|
||||
@@ -712,7 +712,7 @@ async def test_acompletion_claude_2_stream():
|
||||
@pytest.mark.flaky(retries=3, delay=1)
|
||||
async def test_completion_gemini_stream(sync_mode):
|
||||
try:
|
||||
litellm.set_verbose = True
|
||||
litellm._turn_on_debug()
|
||||
print("Streaming gemini response")
|
||||
function1 = [
|
||||
{
|
||||
|
||||
+49
-6
@@ -322,6 +322,8 @@ def test_streaming_chunk_includes_reasoning_tokens():
|
||||
ModelResponseIterator,
|
||||
)
|
||||
|
||||
litellm_logging = MagicMock()
|
||||
|
||||
# Simulate a streaming chunk as would be received from Gemini
|
||||
chunk = {
|
||||
"candidates": [{"content": {"parts": [{"text": "Hello"}]}}],
|
||||
@@ -332,14 +334,54 @@ def test_streaming_chunk_includes_reasoning_tokens():
|
||||
"thoughtsTokenCount": 3,
|
||||
},
|
||||
}
|
||||
iterator = ModelResponseIterator(streaming_response=[], sync_stream=True)
|
||||
iterator = ModelResponseIterator(
|
||||
streaming_response=[], sync_stream=True, logging_obj=litellm_logging
|
||||
)
|
||||
streaming_chunk = iterator.chunk_parser(chunk)
|
||||
assert streaming_chunk["usage"] is not None
|
||||
assert streaming_chunk["usage"]["prompt_tokens"] == 5
|
||||
assert streaming_chunk["usage"]["completion_tokens"] == 7
|
||||
assert streaming_chunk["usage"]["total_tokens"] == 12
|
||||
assert streaming_chunk.usage is not None
|
||||
assert streaming_chunk.usage.prompt_tokens == 5
|
||||
assert streaming_chunk.usage.completion_tokens == 7
|
||||
assert streaming_chunk.usage.total_tokens == 12
|
||||
assert streaming_chunk.usage.completion_tokens_details.reasoning_tokens == 3
|
||||
|
||||
|
||||
def test_streaming_chunk_includes_reasoning_content():
|
||||
"""
|
||||
Ensure that when Gemini returns a chunk with `thought=True`, the parser maps it to `reasoning_content`.
|
||||
"""
|
||||
from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import (
|
||||
ModelResponseIterator,
|
||||
)
|
||||
|
||||
litellm_logging = MagicMock()
|
||||
|
||||
# Simulate a streaming chunk from Gemini which contains reasoning (thought) content
|
||||
chunk = {
|
||||
"candidates": [
|
||||
{
|
||||
"content": {
|
||||
"parts": [
|
||||
{
|
||||
"text": "I'm thinking through the problem...",
|
||||
"thought": True,
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
],
|
||||
"usageMetadata": {},
|
||||
}
|
||||
|
||||
iterator = ModelResponseIterator(
|
||||
streaming_response=[], sync_stream=True, logging_obj=litellm_logging
|
||||
)
|
||||
streaming_chunk = iterator.chunk_parser(chunk)
|
||||
|
||||
# The text content should be empty and reasoning_content should be populated
|
||||
assert streaming_chunk.choices[0].delta.content is None
|
||||
assert (
|
||||
streaming_chunk["usage"]["completion_tokens_details"]["reasoning_tokens"] == 3
|
||||
streaming_chunk.choices[0].delta.reasoning_content
|
||||
== "I'm thinking through the problem..."
|
||||
)
|
||||
|
||||
|
||||
@@ -393,6 +435,7 @@ def test_vertex_ai_map_thinking_param_with_budget_tokens_0():
|
||||
"thinkingBudget": 100,
|
||||
}
|
||||
|
||||
|
||||
def test_vertex_ai_map_tools():
|
||||
v = VertexGeminiConfig()
|
||||
tools = v._map_function(value=[{"code_execution": {}}])
|
||||
|
||||
Reference in New Issue
Block a user