Merge pull request #24802 from BerriAI/litellm_/modest-easley

[Refactor] Extract helper methods in guardrail handlers to fix PLR0915
This commit is contained in:
yuneng-jiang
2026-03-30 11:42:12 -07:00
committed by GitHub
2 changed files with 143 additions and 95 deletions
@@ -234,37 +234,13 @@ class A2AGuardrailHandler(BaseTranslation):
then the combined guardrailed text is written into the first chunk that had text
and all other text parts in other chunks are cleared (in-place).
"""
from litellm.llms.a2a.common_utils import extract_text_from_a2a_response
# Parse each item; keep alignment with responses_so_far (None where unparseable)
parsed: List[Optional[Dict[str, Any]]] = [None] * len(responses_so_far)
for i, item in enumerate(responses_so_far):
if isinstance(item, dict):
obj = item
elif isinstance(item, str):
try:
obj = json.loads(item.strip())
except (json.JSONDecodeError, TypeError):
continue
else:
continue
if isinstance(obj.get("result"), dict):
parsed[i] = obj
valid_parsed = [(i, obj) for i, obj in enumerate(parsed) if obj is not None]
parsed, valid_parsed = self._parse_streaming_responses(responses_so_far)
if not valid_parsed:
return responses_so_far
# Collect text from each chunk in order (by original index in responses_so_far)
text_parts: List[str] = []
chunk_indices_with_text: List[int] = [] # indices into valid_parsed
for idx, (orig_i, obj) in enumerate(valid_parsed):
t = extract_text_from_a2a_response(obj)
if t:
text_parts.append(t)
chunk_indices_with_text.append(orig_i)
combined_text = "".join(text_parts)
combined_text, chunk_indices_with_text = (
self._collect_text_from_parsed_chunks(valid_parsed)
)
if not combined_text:
return responses_so_far
@@ -337,6 +313,45 @@ class A2AGuardrailHandler(BaseTranslation):
return responses_so_far
def _parse_streaming_responses(
self,
responses_so_far: List[Any],
) -> Tuple[
List[Optional[Dict[str, Any]]], List[Tuple[int, Dict[str, Any]]]
]:
"""Parse JSON-RPC items, returning aligned parsed list and valid entries."""
parsed: List[Optional[Dict[str, Any]]] = [None] * len(responses_so_far)
for i, item in enumerate(responses_so_far):
if isinstance(item, dict):
obj = item
elif isinstance(item, str):
try:
obj = json.loads(item.strip())
except (json.JSONDecodeError, TypeError):
continue
else:
continue
if isinstance(obj.get("result"), dict):
parsed[i] = obj
valid_parsed = [(i, obj) for i, obj in enumerate(parsed) if obj is not None]
return parsed, valid_parsed
def _collect_text_from_parsed_chunks(
    self,
    valid_parsed: List[Tuple[int, Dict[str, Any]]],
) -> Tuple[str, List[int]]:
    """Gather the text carried by each parsed chunk.

    Args:
        valid_parsed: ``(original_index, chunk)`` pairs, visited in the
            order given.

    Returns:
        A tuple of (all non-empty chunk texts concatenated with no
        separator, the original indices of the chunks that contributed
        text). Chunks for which ``extract_text_from_a2a_response`` returns a
        falsy value are left out of both.
    """
    from litellm.llms.a2a.common_utils import extract_text_from_a2a_response

    fragments: List[str] = []
    source_indices: List[int] = []
    for original_index, chunk in valid_parsed:
        fragment = extract_text_from_a2a_response(chunk)
        if not fragment:
            continue
        fragments.append(fragment)
        source_indices.append(original_index)
    return "".join(fragments), source_indices
def _extract_texts_from_result(
self,
result: Dict[str, Any],
@@ -277,82 +277,26 @@ class AnthropicMessagesHandler(BaseTranslation):
images_to_check: List[str] = []
tool_calls_to_check: List[ChatCompletionToolCallChunk] = []
task_mappings: List[Tuple[int, Optional[int]]] = []
# Track (content_index, None) for each text
# Handle both dict and object responses
response_content: List[Any] = []
if isinstance(response, dict):
response_content = response.get("content", []) or []
elif hasattr(response, "content"):
content = getattr(response, "content", None)
response_content = content or []
else:
response_content = []
response_content = self._get_response_content(response)
if not response_content:
return response
# Step 1: Extract all text content and tool calls from response
for content_idx, content_block in enumerate(response_content):
# Handle both dict and Pydantic object content blocks
block_dict: Dict[str, Any] = {}
if isinstance(content_block, dict):
block_type = content_block.get("type")
block_dict = cast(Dict[str, Any], content_block)
elif hasattr(content_block, "type"):
block_type = getattr(content_block, "type", None)
# Convert Pydantic object to dict for processing
if hasattr(content_block, "model_dump"):
block_dict = content_block.model_dump()
else:
block_dict = {
"type": block_type,
"text": getattr(content_block, "text", None),
}
else:
continue
if block_type in ["text", "tool_use"]:
self._extract_output_text_and_images(
content_block=block_dict,
content_idx=content_idx,
texts_to_check=texts_to_check,
images_to_check=images_to_check,
task_mappings=task_mappings,
tool_calls_to_check=tool_calls_to_check,
)
self._extract_from_content_blocks(
response_content, texts_to_check, images_to_check,
task_mappings, tool_calls_to_check,
)
# Step 2: Apply guardrail to all texts in batch
if texts_to_check or tool_calls_to_check:
# Use the real request_data if provided (proxy path), otherwise
# create a standalone dict (SDK / direct-call path).
if request_data is None:
request_data = {"response": response}
else:
if "response" not in request_data:
request_data["response"] = response
request_data = self._prepare_request_data(
request_data, response, user_api_key_dict, key="response",
)
# Add user API key metadata with prefixed keys
if "litellm_metadata" not in request_data:
user_metadata = self.transform_user_api_key_dict_to_metadata(
user_api_key_dict
)
if user_metadata:
request_data["litellm_metadata"] = user_metadata
inputs = GenericGuardrailAPIInputs(texts=texts_to_check)
if images_to_check:
inputs["images"] = images_to_check
if tool_calls_to_check:
inputs["tool_calls"] = tool_calls_to_check
# Include model information from the response if available
response_model = None
if isinstance(response, dict):
response_model = response.get("model")
elif hasattr(response, "model"):
response_model = getattr(response, "model", None)
if response_model:
inputs["model"] = response_model
inputs = self._build_guardrail_inputs(
texts_to_check, images_to_check, tool_calls_to_check, response,
)
guardrailed_inputs = await guardrail_to_apply.apply_guardrail(
inputs=inputs,
@@ -440,6 +384,95 @@ class AnthropicMessagesHandler(BaseTranslation):
)
return responses_so_far
def _prepare_request_data(
self,
request_data: Optional[dict],
response: Any,
user_api_key_dict: Optional[Any],
key: str,
) -> dict:
"""Ensure request_data has the response/responses_so_far key and metadata."""
if request_data is None:
request_data = {key: response}
else:
if key not in request_data:
request_data[key] = response
if "litellm_metadata" not in request_data:
user_metadata = self.transform_user_api_key_dict_to_metadata(
user_api_key_dict
)
if user_metadata:
request_data["litellm_metadata"] = user_metadata
return request_data
@staticmethod
def _get_response_content(response: Any) -> List[Any]:
"""Extract content list from a dict or object response."""
if isinstance(response, dict):
return response.get("content", []) or []
elif hasattr(response, "content"):
return getattr(response, "content", None) or []
return []
def _extract_from_content_blocks(
self,
response_content: List[Any],
texts_to_check: List[str],
images_to_check: List[str],
task_mappings: List[Tuple[int, Optional[int]]],
tool_calls_to_check: List["ChatCompletionToolCallChunk"],
) -> None:
"""Extract text, images, and tool calls from content blocks."""
for content_idx, content_block in enumerate(response_content):
block_dict: Dict[str, Any] = {}
if isinstance(content_block, dict):
block_type = content_block.get("type")
block_dict = cast(Dict[str, Any], content_block)
elif hasattr(content_block, "type"):
block_type = getattr(content_block, "type", None)
if hasattr(content_block, "model_dump"):
block_dict = content_block.model_dump()
else:
block_dict = {
"type": block_type,
"text": getattr(content_block, "text", None),
}
else:
continue
if block_type in ["text", "tool_use"]:
self._extract_output_text_and_images(
content_block=block_dict,
content_idx=content_idx,
texts_to_check=texts_to_check,
images_to_check=images_to_check,
task_mappings=task_mappings,
tool_calls_to_check=tool_calls_to_check,
)
@staticmethod
def _build_guardrail_inputs(
    texts_to_check: List[str],
    images_to_check: List[str],
    tool_calls_to_check: List["ChatCompletionToolCallChunk"],
    response: Any,
) -> "GenericGuardrailAPIInputs":
    """Assemble the guardrail inputs for a response.

    ``texts`` is always set. ``images`` and ``tool_calls`` are added only
    when their lists are non-empty, and ``model`` only when the response
    (read by key for dicts, by attribute otherwise) has a truthy ``model``.
    """
    inputs = GenericGuardrailAPIInputs(texts=texts_to_check)
    if images_to_check:
        inputs["images"] = images_to_check
    if tool_calls_to_check:
        inputs["tool_calls"] = tool_calls_to_check

    if isinstance(response, dict):
        model_name = response.get("model")
    else:
        model_name = getattr(response, "model", None)
    if model_name:
        inputs["model"] = model_name
    return inputs
def get_streaming_string_so_far(self, responses_so_far: List[Any]) -> str:
"""
Parse streaming responses and extract accumulated text content.