mirror of
https://github.com/tiennm99/litellm.git
synced 2026-06-17 22:48:35 +00:00
Merge pull request #24802 from BerriAI/litellm_/modest-easley
[Refactor] Extract helper methods in guardrail handlers to fix PLR0915
This commit is contained in:
@@ -234,37 +234,13 @@ class A2AGuardrailHandler(BaseTranslation):
|
||||
then the combined guardrailed text is written into the first chunk that had text
|
||||
and all other text parts in other chunks are cleared (in-place).
|
||||
"""
|
||||
from litellm.llms.a2a.common_utils import extract_text_from_a2a_response
|
||||
|
||||
# Parse each item; keep alignment with responses_so_far (None where unparseable)
|
||||
parsed: List[Optional[Dict[str, Any]]] = [None] * len(responses_so_far)
|
||||
for i, item in enumerate(responses_so_far):
|
||||
if isinstance(item, dict):
|
||||
obj = item
|
||||
elif isinstance(item, str):
|
||||
try:
|
||||
obj = json.loads(item.strip())
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
continue
|
||||
else:
|
||||
continue
|
||||
if isinstance(obj.get("result"), dict):
|
||||
parsed[i] = obj
|
||||
|
||||
valid_parsed = [(i, obj) for i, obj in enumerate(parsed) if obj is not None]
|
||||
parsed, valid_parsed = self._parse_streaming_responses(responses_so_far)
|
||||
if not valid_parsed:
|
||||
return responses_so_far
|
||||
|
||||
# Collect text from each chunk in order (by original index in responses_so_far)
|
||||
text_parts: List[str] = []
|
||||
chunk_indices_with_text: List[int] = [] # indices into valid_parsed
|
||||
for idx, (orig_i, obj) in enumerate(valid_parsed):
|
||||
t = extract_text_from_a2a_response(obj)
|
||||
if t:
|
||||
text_parts.append(t)
|
||||
chunk_indices_with_text.append(orig_i)
|
||||
|
||||
combined_text = "".join(text_parts)
|
||||
combined_text, chunk_indices_with_text = (
|
||||
self._collect_text_from_parsed_chunks(valid_parsed)
|
||||
)
|
||||
if not combined_text:
|
||||
return responses_so_far
|
||||
|
||||
@@ -337,6 +313,45 @@ class A2AGuardrailHandler(BaseTranslation):
|
||||
|
||||
return responses_so_far
|
||||
|
||||
def _parse_streaming_responses(
|
||||
self,
|
||||
responses_so_far: List[Any],
|
||||
) -> Tuple[
|
||||
List[Optional[Dict[str, Any]]], List[Tuple[int, Dict[str, Any]]]
|
||||
]:
|
||||
"""Parse JSON-RPC items, returning aligned parsed list and valid entries."""
|
||||
parsed: List[Optional[Dict[str, Any]]] = [None] * len(responses_so_far)
|
||||
for i, item in enumerate(responses_so_far):
|
||||
if isinstance(item, dict):
|
||||
obj = item
|
||||
elif isinstance(item, str):
|
||||
try:
|
||||
obj = json.loads(item.strip())
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
continue
|
||||
else:
|
||||
continue
|
||||
if isinstance(obj.get("result"), dict):
|
||||
parsed[i] = obj
|
||||
valid_parsed = [(i, obj) for i, obj in enumerate(parsed) if obj is not None]
|
||||
return parsed, valid_parsed
|
||||
|
||||
def _collect_text_from_parsed_chunks(
|
||||
self,
|
||||
valid_parsed: List[Tuple[int, Dict[str, Any]]],
|
||||
) -> Tuple[str, List[int]]:
|
||||
"""Collect text from parsed chunks, returning combined text and indices."""
|
||||
from litellm.llms.a2a.common_utils import extract_text_from_a2a_response
|
||||
|
||||
text_parts: List[str] = []
|
||||
chunk_indices_with_text: List[int] = []
|
||||
for _idx, (orig_i, obj) in enumerate(valid_parsed):
|
||||
t = extract_text_from_a2a_response(obj)
|
||||
if t:
|
||||
text_parts.append(t)
|
||||
chunk_indices_with_text.append(orig_i)
|
||||
return "".join(text_parts), chunk_indices_with_text
|
||||
|
||||
def _extract_texts_from_result(
|
||||
self,
|
||||
result: Dict[str, Any],
|
||||
|
||||
@@ -277,82 +277,26 @@ class AnthropicMessagesHandler(BaseTranslation):
|
||||
images_to_check: List[str] = []
|
||||
tool_calls_to_check: List[ChatCompletionToolCallChunk] = []
|
||||
task_mappings: List[Tuple[int, Optional[int]]] = []
|
||||
# Track (content_index, None) for each text
|
||||
|
||||
# Handle both dict and object responses
|
||||
response_content: List[Any] = []
|
||||
if isinstance(response, dict):
|
||||
response_content = response.get("content", []) or []
|
||||
elif hasattr(response, "content"):
|
||||
content = getattr(response, "content", None)
|
||||
response_content = content or []
|
||||
else:
|
||||
response_content = []
|
||||
|
||||
response_content = self._get_response_content(response)
|
||||
if not response_content:
|
||||
return response
|
||||
|
||||
# Step 1: Extract all text content and tool calls from response
|
||||
for content_idx, content_block in enumerate(response_content):
|
||||
# Handle both dict and Pydantic object content blocks
|
||||
block_dict: Dict[str, Any] = {}
|
||||
if isinstance(content_block, dict):
|
||||
block_type = content_block.get("type")
|
||||
block_dict = cast(Dict[str, Any], content_block)
|
||||
elif hasattr(content_block, "type"):
|
||||
block_type = getattr(content_block, "type", None)
|
||||
# Convert Pydantic object to dict for processing
|
||||
if hasattr(content_block, "model_dump"):
|
||||
block_dict = content_block.model_dump()
|
||||
else:
|
||||
block_dict = {
|
||||
"type": block_type,
|
||||
"text": getattr(content_block, "text", None),
|
||||
}
|
||||
else:
|
||||
continue
|
||||
|
||||
if block_type in ["text", "tool_use"]:
|
||||
self._extract_output_text_and_images(
|
||||
content_block=block_dict,
|
||||
content_idx=content_idx,
|
||||
texts_to_check=texts_to_check,
|
||||
images_to_check=images_to_check,
|
||||
task_mappings=task_mappings,
|
||||
tool_calls_to_check=tool_calls_to_check,
|
||||
)
|
||||
self._extract_from_content_blocks(
|
||||
response_content, texts_to_check, images_to_check,
|
||||
task_mappings, tool_calls_to_check,
|
||||
)
|
||||
|
||||
# Step 2: Apply guardrail to all texts in batch
|
||||
if texts_to_check or tool_calls_to_check:
|
||||
# Use the real request_data if provided (proxy path), otherwise
|
||||
# create a standalone dict (SDK / direct-call path).
|
||||
if request_data is None:
|
||||
request_data = {"response": response}
|
||||
else:
|
||||
if "response" not in request_data:
|
||||
request_data["response"] = response
|
||||
request_data = self._prepare_request_data(
|
||||
request_data, response, user_api_key_dict, key="response",
|
||||
)
|
||||
|
||||
# Add user API key metadata with prefixed keys
|
||||
if "litellm_metadata" not in request_data:
|
||||
user_metadata = self.transform_user_api_key_dict_to_metadata(
|
||||
user_api_key_dict
|
||||
)
|
||||
if user_metadata:
|
||||
request_data["litellm_metadata"] = user_metadata
|
||||
|
||||
inputs = GenericGuardrailAPIInputs(texts=texts_to_check)
|
||||
if images_to_check:
|
||||
inputs["images"] = images_to_check
|
||||
if tool_calls_to_check:
|
||||
inputs["tool_calls"] = tool_calls_to_check
|
||||
# Include model information from the response if available
|
||||
response_model = None
|
||||
if isinstance(response, dict):
|
||||
response_model = response.get("model")
|
||||
elif hasattr(response, "model"):
|
||||
response_model = getattr(response, "model", None)
|
||||
if response_model:
|
||||
inputs["model"] = response_model
|
||||
inputs = self._build_guardrail_inputs(
|
||||
texts_to_check, images_to_check, tool_calls_to_check, response,
|
||||
)
|
||||
|
||||
guardrailed_inputs = await guardrail_to_apply.apply_guardrail(
|
||||
inputs=inputs,
|
||||
@@ -440,6 +384,95 @@ class AnthropicMessagesHandler(BaseTranslation):
|
||||
)
|
||||
return responses_so_far
|
||||
|
||||
def _prepare_request_data(
|
||||
self,
|
||||
request_data: Optional[dict],
|
||||
response: Any,
|
||||
user_api_key_dict: Optional[Any],
|
||||
key: str,
|
||||
) -> dict:
|
||||
"""Ensure request_data has the response/responses_so_far key and metadata."""
|
||||
if request_data is None:
|
||||
request_data = {key: response}
|
||||
else:
|
||||
if key not in request_data:
|
||||
request_data[key] = response
|
||||
|
||||
if "litellm_metadata" not in request_data:
|
||||
user_metadata = self.transform_user_api_key_dict_to_metadata(
|
||||
user_api_key_dict
|
||||
)
|
||||
if user_metadata:
|
||||
request_data["litellm_metadata"] = user_metadata
|
||||
return request_data
|
||||
|
||||
@staticmethod
|
||||
def _get_response_content(response: Any) -> List[Any]:
|
||||
"""Extract content list from a dict or object response."""
|
||||
if isinstance(response, dict):
|
||||
return response.get("content", []) or []
|
||||
elif hasattr(response, "content"):
|
||||
return getattr(response, "content", None) or []
|
||||
return []
|
||||
|
||||
def _extract_from_content_blocks(
|
||||
self,
|
||||
response_content: List[Any],
|
||||
texts_to_check: List[str],
|
||||
images_to_check: List[str],
|
||||
task_mappings: List[Tuple[int, Optional[int]]],
|
||||
tool_calls_to_check: List["ChatCompletionToolCallChunk"],
|
||||
) -> None:
|
||||
"""Extract text, images, and tool calls from content blocks."""
|
||||
for content_idx, content_block in enumerate(response_content):
|
||||
block_dict: Dict[str, Any] = {}
|
||||
if isinstance(content_block, dict):
|
||||
block_type = content_block.get("type")
|
||||
block_dict = cast(Dict[str, Any], content_block)
|
||||
elif hasattr(content_block, "type"):
|
||||
block_type = getattr(content_block, "type", None)
|
||||
if hasattr(content_block, "model_dump"):
|
||||
block_dict = content_block.model_dump()
|
||||
else:
|
||||
block_dict = {
|
||||
"type": block_type,
|
||||
"text": getattr(content_block, "text", None),
|
||||
}
|
||||
else:
|
||||
continue
|
||||
|
||||
if block_type in ["text", "tool_use"]:
|
||||
self._extract_output_text_and_images(
|
||||
content_block=block_dict,
|
||||
content_idx=content_idx,
|
||||
texts_to_check=texts_to_check,
|
||||
images_to_check=images_to_check,
|
||||
task_mappings=task_mappings,
|
||||
tool_calls_to_check=tool_calls_to_check,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _build_guardrail_inputs(
|
||||
texts_to_check: List[str],
|
||||
images_to_check: List[str],
|
||||
tool_calls_to_check: List["ChatCompletionToolCallChunk"],
|
||||
response: Any,
|
||||
) -> "GenericGuardrailAPIInputs":
|
||||
"""Build GenericGuardrailAPIInputs with optional images, tool calls, model."""
|
||||
inputs = GenericGuardrailAPIInputs(texts=texts_to_check)
|
||||
if images_to_check:
|
||||
inputs["images"] = images_to_check
|
||||
if tool_calls_to_check:
|
||||
inputs["tool_calls"] = tool_calls_to_check
|
||||
response_model = None
|
||||
if isinstance(response, dict):
|
||||
response_model = response.get("model")
|
||||
elif hasattr(response, "model"):
|
||||
response_model = getattr(response, "model", None)
|
||||
if response_model:
|
||||
inputs["model"] = response_model
|
||||
return inputs
|
||||
|
||||
def get_streaming_string_so_far(self, responses_so_far: List[Any]) -> str:
|
||||
"""
|
||||
Parse streaming responses and extract accumulated text content.
|
||||
|
||||
Reference in New Issue
Block a user