Files
litellm/tests/code_coverage_tests/ban_copy_deepcopy_kwargs.py
T
Krish Dholakia c7e4435bdc Fix - using managed files w/ OTEL + UI - add model group alias on UI (#13171)
* fix(router.py): safe deep copy kwargs

OTEL adds a parent_otel_span which cannot be deepcopied

* fix: use safe deep copy in other places as well

* test: add script to check and ban copy.deepcopy of kwargs

enforce safe_deep_copy usage

* build(ui/): new component for adding model group alias on UI

* fix(proxy_server.py): support updating model_group_alias via /config/update

allows ui component to work

* fix(router.py): update model_group_alias in router settings based on db value

* fix: fix code qa error
2025-07-31 21:22:04 -07:00

143 lines
5.3 KiB
Python

import ast
import os
class CopyDeepcopyKwargsDetector(ast.NodeVisitor):
    """AST visitor that records ``deepcopy`` calls made on kwargs-like values.

    A call is considered a deepcopy call when it is spelled ``copy.deepcopy(...)``
    or bare ``deepcopy(...)``. An argument is considered kwargs-related when the
    substring ``kwargs`` (case-insensitive) appears in:

    * a plain variable name (``kwargs``, ``litellm_kwargs``),
    * the final attribute of an attribute access (``self.kwargs``),
    * the base of a subscript, at any nesting depth
      (``kwargs["k"]``, ``self.kwargs["k"]``, ``kwargs["a"]["b"]``).

    Both positional arguments and the ``x=`` keyword form are inspected.

    Attributes:
        violations: list of dicts, one per offending argument, with the keys
            ``"line"`` (line number of the call), ``"arg_name"`` (readable
            name of the argument) and ``"full_call"`` (source text of the call).
    """

    def __init__(self):
        # One entry per offending argument, in visit order.
        self.violations = []

    def visit_Call(self, node):
        """Record a violation for each kwargs-related argument of a deepcopy call."""
        if self._is_copy_deepcopy_call(node):
            # Positional arguments first, then the ``x=`` keyword spelling so
            # that ``copy.deepcopy(x=kwargs)`` cannot slip past the check.
            candidates = list(node.args)
            candidates.extend(kw.value for kw in node.keywords if kw.arg == "x")

            full_call = None
            for arg in candidates:
                if not self._is_kwargs_related(arg):
                    continue
                if full_call is None:
                    # ast.unparse only exists on newer Pythons; fall back to
                    # the node's str() when it is unavailable.
                    full_call = (
                        ast.unparse(node) if hasattr(ast, "unparse") else str(node)
                    )
                self.violations.append(
                    {
                        "line": node.lineno,
                        "arg_name": self._get_arg_name(arg),
                        "full_call": full_call,
                    }
                )
        # Always descend so that calls nested inside this one are checked too.
        self.generic_visit(node)

    def _is_copy_deepcopy_call(self, node):
        """Return True if *node* is a ``copy.deepcopy(...)`` or ``deepcopy(...)`` call."""
        func = node.func
        if isinstance(func, ast.Attribute):
            # Case: copy.deepcopy()
            return (
                func.attr == "deepcopy"
                and isinstance(func.value, ast.Name)
                and func.value.id == "copy"
            )
        if isinstance(func, ast.Name):
            # Case: deepcopy() (if imported as 'from copy import deepcopy')
            return func.id == "deepcopy"
        return False

    def _is_kwargs_related(self, arg):
        """Return True if *arg* is a kwargs-like name, attribute, or subscript thereof."""
        if isinstance(arg, ast.Name):
            # Direct variable names containing 'kwargs'
            return "kwargs" in arg.id.lower()
        if isinstance(arg, ast.Attribute):
            # Handle cases like self.kwargs
            return "kwargs" in arg.attr.lower()
        if isinstance(arg, ast.Subscript):
            # Handle kwargs['key'], self.kwargs['key'] and kwargs['a']['b'] by
            # judging the expression being indexed.
            return self._is_kwargs_related(arg.value)
        return False

    def _get_arg_name(self, arg):
        """Return a short readable name for *arg* for use in violation reports."""
        unknown = "unknown_kwargs_variable"
        if isinstance(arg, ast.Name):
            return arg.id
        if isinstance(arg, ast.Attribute):
            return f"...{arg.attr}"
        if isinstance(arg, ast.Subscript):
            base = self._get_arg_name(arg.value)
            # Keep the plain placeholder when the indexed expression itself
            # has no readable name.
            return base if base == unknown else f"{base}[...]"
        return unknown
def find_copy_deepcopy_kwargs_in_file(file_path):
    """Scan one Python file for deepcopy calls on kwargs-like values.

    Args:
        file_path: path of the file to read (as UTF-8) and parse.

    Returns:
        The list of violation dicts collected by CopyDeepcopyKwargsDetector.
        If the file cannot be read, parsed, or visited, the error is printed
        and an empty list is returned instead.
    """
    found = []
    try:
        with open(file_path, "r", encoding="utf-8") as handle:
            source_text = handle.read()
        module_ast = ast.parse(source_text, filename=file_path)
        visitor = CopyDeepcopyKwargsDetector()
        visitor.visit(module_ast)
        found = visitor.violations
    except Exception as err:
        # Best-effort scan: report the problem and treat the file as clean.
        print(f"Error parsing {file_path}: {err}")
    return found
def find_copy_deepcopy_kwargs_in_directory(directory):
    """Scan every ``.py`` file under *directory* for deepcopy calls on kwargs.

    Args:
        directory: root directory to walk recursively.

    Returns:
        A dict mapping each file path that has at least one violation to the
        list of violations found in it. Files without violations are omitted.
    """
    per_file = {}
    for dirpath, _dirnames, filenames in os.walk(directory):
        for filename in filenames:
            # Only Python sources are of interest.
            if not filename.endswith(".py"):
                continue
            file_path = os.path.join(dirpath, filename)
            print(f"Checking file: {file_path}")
            hits = find_copy_deepcopy_kwargs_in_file(file_path)
            if hits:
                per_file[file_path] = hits
    return per_file
if __name__ == "__main__":
    # Scan the litellm directory for copy.deepcopy(kwargs) usage, print a
    # report, and raise if anything was found so the check fails loudly.
    directory_path = "./litellm"
    violations = find_copy_deepcopy_kwargs_in_directory(directory_path)

    print("\n" + "=" * 80)
    print("COPY.DEEPCOPY KWARGS VIOLATIONS FOUND:")
    print("=" * 80)

    if not violations:
        print("✅ No copy.deepcopy(kwargs) violations found!")
        print("✅ All kwargs copying appears to use safe_deep_copy() correctly.")
    else:
        # Per-file listing of every offending call.
        for file_path, file_violations in violations.items():
            print(f"\n📁 File: {file_path}")
            for violation in file_violations:
                print(
                    f" ❌ Line {violation['line']}: copy.deepcopy({violation['arg_name']})"
                )
                print(f" Full call: {violation['full_call']}")

        total_violations = sum(len(hits) for hits in violations.values())

        print(f"\n{'='*80}")
        print(f"🚨 TOTAL VIOLATIONS: {total_violations}")
        print("🚨 USE safe_deep_copy() INSTEAD OF copy.deepcopy() FOR KWARGS!")
        print("🚨 Available imports:")
        print(" - from litellm.proxy.utils import safe_deep_copy")
        print(" - from litellm.litellm_core_utils.core_helpers import safe_deep_copy")
        print("=" * 80)

        # The exception message cites the first violation of the first file.
        first_file, first_file_hits = next(iter(violations.items()))
        first_violation = first_file_hits[0]
        raise Exception(
            f"🚨 Found {total_violations} copy.deepcopy(kwargs) violations! "
            f"First violation: {first_file}:{first_violation['line']} - "
            f"copy.deepcopy({first_violation['arg_name']}). "
            f"Use safe_deep_copy() instead to handle non-serializable objects like OTEL spans."
        )