diff --git a/litellm/litellm_core_utils/streaming_handler.py b/litellm/litellm_core_utils/streaming_handler.py index 2c7af8d5ba..5d5a8bf256 100644 --- a/litellm/litellm_core_utils/streaming_handler.py +++ b/litellm/litellm_core_utils/streaming_handler.py @@ -637,7 +637,10 @@ class CustomStreamWrapper: if isinstance(chunk, bytes): chunk = chunk.decode("utf-8") if "text_output" in chunk: - response = chunk.replace("data: ", "").strip() + response = ( + CustomStreamWrapper._strip_sse_data_from_chunk(chunk) or "" + ) + response = response.strip() parsed_response = json.loads(response) else: return { @@ -1828,6 +1831,42 @@ class CustomStreamWrapper: extra_kwargs={}, ) + @staticmethod + def _strip_sse_data_from_chunk(chunk: Optional[str]) -> Optional[str]: + """ + Strips the 'data: ' prefix from Server-Sent Events (SSE) chunks. + + Some providers like sagemaker send it as `data:`, need to handle both + + SSE messages are prefixed with 'data: ' which is part of the protocol, + not the actual content from the LLM. This method removes that prefix + and returns the actual content. + + Args: + chunk: The SSE chunk that may contain the 'data: ' prefix (string or bytes) + + Returns: + The chunk with the 'data: ' prefix removed, or the original chunk + if no prefix was found. Returns None if input is None. + + See OpenAI Python Ref for this: https://github.com/openai/openai-python/blob/041bf5a8ec54da19aad0169671793c2078bd6173/openai/api_requestor.py#L100 + """ + if chunk is None: + return None + + if isinstance(chunk, str): + # OpenAI sends `data: ` + if chunk.startswith("data: "): + # Strip the prefix and any leading whitespace that might follow it + _length_of_sse_data_prefix = len("data: ") + return chunk[_length_of_sse_data_prefix:] + elif chunk.startswith("data:"): + # Sagemaker sends `data:`, no trailing whitespace + _length_of_sse_data_prefix = len("data:") + return chunk[_length_of_sse_data_prefix:] + + return chunk + def calculate_total_usage(chunks: List[ModelResponse]) -> Usage: """Assume most recent usage chunk has total usage uptil then.""" diff --git a/litellm/llms/codestral/completion/transformation.py b/litellm/llms/codestral/completion/transformation.py index 84551cd553..5955e91deb 100644 --- a/litellm/llms/codestral/completion/transformation.py +++ b/litellm/llms/codestral/completion/transformation.py @@ -84,7 +84,9 @@ class CodestralTextCompletionConfig(OpenAITextCompletionConfig): finish_reason = None logprobs = None - chunk_data = chunk_data.replace("data:", "") + chunk_data = ( + litellm.CustomStreamWrapper._strip_sse_data_from_chunk(chunk_data) or "" + ) chunk_data = chunk_data.strip() if len(chunk_data) == 0 or chunk_data == "[DONE]": return { diff --git a/litellm/llms/databricks/streaming_utils.py b/litellm/llms/databricks/streaming_utils.py index 0deaa06988..2db53df908 100644 --- a/litellm/llms/databricks/streaming_utils.py +++ b/litellm/llms/databricks/streaming_utils.py @@ -89,7 +89,7 @@ class ModelResponseIterator: raise RuntimeError(f"Error receiving chunk from stream: {e}") try: - chunk = chunk.replace("data:", "") + chunk = litellm.CustomStreamWrapper._strip_sse_data_from_chunk(chunk) or "" chunk = chunk.strip() if len(chunk) > 0: json_chunk = json.loads(chunk) @@ -134,7 +134,7 @@ class ModelResponseIterator: raise RuntimeError(f"Error receiving chunk from stream: {e}") try: - chunk = chunk.replace("data:", "") + chunk = litellm.CustomStreamWrapper._strip_sse_data_from_chunk(chunk) or "" chunk = chunk.strip() if chunk == "[DONE]": raise StopAsyncIteration diff --git a/litellm/llms/sagemaker/common_utils.py b/litellm/llms/sagemaker/common_utils.py index 49e4989ff1..9884f420c3 100644 --- a/litellm/llms/sagemaker/common_utils.py +++ b/litellm/llms/sagemaker/common_utils.py @@ -3,6 +3,7 @@ from typing import AsyncIterator, Iterator, List, Optional, Union import httpx +import litellm from litellm import verbose_logger from litellm.llms.base_llm.chat.transformation import BaseLLMException from litellm.types.utils import GenericStreamingChunk as GChunk @@ -78,7 +79,11 @@ class AWSEventStreamDecoder: message = self._parse_message_from_event(event) if message: # remove data: prefix and "\n\n" at the end - message = message.replace("data:", "").replace("\n\n", "") + message = ( + litellm.CustomStreamWrapper._strip_sse_data_from_chunk(message) + or "" + ) + message = message.replace("\n\n", "") # Accumulate JSON data accumulated_json += message @@ -127,7 +132,11 @@ class AWSEventStreamDecoder: if message: verbose_logger.debug("sagemaker parsed chunk bytes %s", message) # remove data: prefix and "\n\n" at the end - message = message.replace("data:", "").replace("\n\n", "") + message = ( + litellm.CustomStreamWrapper._strip_sse_data_from_chunk(message) + or "" + ) + message = message.replace("\n\n", "") # Accumulate JSON data accumulated_json += message diff --git a/litellm/llms/vertex_ai/gemini/vertex_and_google_ai_studio_gemini.py b/litellm/llms/vertex_ai/gemini/vertex_and_google_ai_studio_gemini.py index 20f9c9da47..294939a3c5 100644 --- a/litellm/llms/vertex_ai/gemini/vertex_and_google_ai_studio_gemini.py +++ b/litellm/llms/vertex_ai/gemini/vertex_and_google_ai_studio_gemini.py @@ -1408,7 +1408,8 @@ class ModelResponseIterator: return self.chunk_parser(chunk=json_chunk) def handle_accumulated_json_chunk(self, chunk: str) -> GenericStreamingChunk: - message = chunk.replace("data:", "").replace("\n\n", "") + chunk = litellm.CustomStreamWrapper._strip_sse_data_from_chunk(chunk) or "" + message = chunk.replace("\n\n", "") # Accumulate JSON data self.accumulated_json += message @@ -1431,7 +1432,7 @@ class ModelResponseIterator: def _common_chunk_parsing_logic(self, chunk: str) -> GenericStreamingChunk: try: - chunk = chunk.replace("data:", "") + chunk = litellm.CustomStreamWrapper._strip_sse_data_from_chunk(chunk) or "" if len(chunk) > 0: """ Check if initial chunk valid json diff --git a/tests/code_coverage_tests/check_data_replace_usage.py b/tests/code_coverage_tests/check_data_replace_usage.py new file mode 100644 index 0000000000..088e6e29b3 --- /dev/null +++ b/tests/code_coverage_tests/check_data_replace_usage.py @@ -0,0 +1,133 @@ +import os +import re +import ast +from pathlib import Path + + +class DataReplaceVisitor(ast.NodeVisitor): + """AST visitor that finds calls to .replace("data:", ...) in the code.""" + + def __init__(self): + self.issues = [] + self.current_file = None + + def set_file(self, filename): + self.current_file = filename + + def visit_Call(self, node): + # Check for method calls like x.replace(...) + if isinstance(node.func, ast.Attribute) and node.func.attr == "replace": + # Check if first argument is "data:" + if ( + len(node.args) >= 2 + and isinstance(node.args[0], ast.Constant) + and isinstance(node.args[0].value, str) + and "data:" in node.args[0].value + ): + + self.issues.append( + { + "file": self.current_file, + "line": node.lineno, + "col": node.col_offset, + "text": f'Found .replace("data:", ...) at line {node.lineno}', + } + ) + + # Continue visiting child nodes + self.generic_visit(node) + + +def check_file_with_ast(file_path): + """Check a Python file for .replace("data:", ...) using AST parsing.""" + with open(file_path, "r", encoding="utf-8") as f: + try: + tree = ast.parse(f.read(), filename=file_path) + visitor = DataReplaceVisitor() + visitor.set_file(file_path) + visitor.visit(tree) + return visitor.issues + except SyntaxError: + return [ + { + "file": file_path, + "line": 0, + "col": 0, + "text": f"Syntax error in file, could not parse", + } + ] + + +def check_file_with_regex(file_path): + """Check any file for .replace("data:", ...) using regex.""" + issues = [] + with open(file_path, "r", encoding="utf-8", errors="ignore") as f: + for i, line in enumerate(f, 1): + matches = re.finditer(r'\.replace\(\s*[\'"]data:[\'"]', line) + for match in matches: + issues.append( + { + "file": file_path, + "line": i, + "col": match.start(), + "text": f'Found .replace("data:", ...) at line {i}', + } + ) + return issues + + +def scan_directory(base_dir): + """Scan a directory recursively for files containing .replace("data:", ...).""" + all_issues = [] + + for root, _, files in os.walk(base_dir): + for file in files: + print("checking file: ", file) + file_path = os.path.join(root, file) + + # Skip directories we don't want to check + if any( + d in file_path for d in [".git", "__pycache__", ".venv", "node_modules"] + ): + continue + + # For Python files, use AST for more accurate parsing + if file.endswith(".py"): + issues = check_file_with_ast(file_path) + # For other files that might contain code, use regex + elif file.endswith((".js", ".ts", ".jsx", ".tsx", ".md", ".ipynb")): + issues = check_file_with_regex(file_path) + else: + continue + + all_issues.extend(issues) + + return all_issues + + +def main(): + # Start from the project root directory + + base_dir = "./litellm" + + # Local testing + # base_dir = "../../litellm" + + print(f"Scanning for .replace('data:', ...) usage in {base_dir}") + issues = scan_directory(base_dir) + + if issues: + print(f"\n⚠️ Found {len(issues)} instances of .replace('data:', ...):") + for issue in issues: + print(f"{issue['file']}:{issue['line']} - {issue['text']}") + + # Fail the test if issues are found + raise Exception( + f"Found {len(issues)} instances of .replace('data:', ...) which may be unsafe. Use litellm.CustomStreamWrapper._strip_sse_data_from_chunk instead." + ) + else: + print("✅ No instances of .replace('data:', ...) found.") + + +if __name__ == "__main__": + main() diff --git a/tests/litellm/litellm_core_utils/test_streaming_handler.py b/tests/litellm/litellm_core_utils/test_streaming_handler.py index 54d178e3ac..10fe1db4ab 100644 --- a/tests/litellm/litellm_core_utils/test_streaming_handler.py +++ b/tests/litellm/litellm_core_utils/test_streaming_handler.py @@ -256,3 +256,24 @@ def test_multi_chunk_reasoning_and_content( # Verify final state assert initialized_custom_stream_wrapper.sent_first_thinking_block is True assert initialized_custom_stream_wrapper.sent_last_thinking_block is True + + +def test_strip_sse_data_from_chunk(): + """Test the static method that strips 'data: ' prefix from SSE chunks""" + # Test with string inputs + assert CustomStreamWrapper._strip_sse_data_from_chunk("data: content") == "content" + assert ( + CustomStreamWrapper._strip_sse_data_from_chunk("data: spaced content") + == " spaced content" + ) + assert ( + CustomStreamWrapper._strip_sse_data_from_chunk("regular content") + == "regular content" + ) + assert ( + CustomStreamWrapper._strip_sse_data_from_chunk("regular content with data:") + == "regular content with data:" + ) + + # Test with None input + assert CustomStreamWrapper._strip_sse_data_from_chunk(None) is None