mirror of
https://github.com/tiennm99/litellm.git
synced 2026-06-17 14:48:44 +00:00
106 lines
3.6 KiB
Python
106 lines
3.6 KiB
Python
import ast
|
|
import os
|
|
|
|
|
|
class EnterpriseImportFinder(ast.NodeVisitor):
|
|
def __init__(self):
|
|
self.unsafe_imports = []
|
|
self.current_file = None
|
|
self.in_try_block = False
|
|
self.try_blocks = []
|
|
|
|
def visit_Try(self, node):
|
|
# Track that we're entering a try block
|
|
self.in_try_block = True
|
|
self.try_blocks.append(node)
|
|
# Visit all nodes in the try block
|
|
for item in node.body:
|
|
self.visit(item)
|
|
# Visit except blocks
|
|
for handler in node.handlers:
|
|
for item in handler.body:
|
|
self.visit(item)
|
|
# Visit else block if it exists
|
|
for item in node.orelse:
|
|
self.visit(item)
|
|
# Visit finally block if it exists
|
|
for item in node.finalbody:
|
|
self.visit(item)
|
|
# We're leaving the try block
|
|
self.try_blocks.pop()
|
|
self.in_try_block = len(self.try_blocks) > 0
|
|
|
|
def visit_Import(self, node):
|
|
# Check for direct imports of litellm_enterprise
|
|
for name in node.names:
|
|
if "litellm_enterprise" in name.name or "enterprise" in name.name:
|
|
if not self.in_try_block:
|
|
self.unsafe_imports.append(
|
|
{
|
|
"file": self.current_file,
|
|
"line": node.lineno,
|
|
"import": name.name,
|
|
"context": "direct import",
|
|
}
|
|
)
|
|
self.generic_visit(node)
|
|
|
|
def visit_ImportFrom(self, node):
|
|
# Check for from litellm_enterprise imports
|
|
if node.module and (
|
|
"litellm_enterprise" in node.module or "enterprise" in node.module
|
|
):
|
|
if not self.in_try_block:
|
|
self.unsafe_imports.append(
|
|
{
|
|
"file": self.current_file,
|
|
"line": node.lineno,
|
|
"import": f"from {node.module}",
|
|
"context": "from import",
|
|
}
|
|
)
|
|
self.generic_visit(node)
|
|
|
|
|
|
def find_unsafe_enterprise_imports_in_file(file_path):
|
|
with open(file_path, "r") as file:
|
|
tree = ast.parse(file.read(), filename=file_path)
|
|
finder = EnterpriseImportFinder()
|
|
finder.current_file = file_path
|
|
finder.visit(tree)
|
|
return finder.unsafe_imports
|
|
|
|
|
|
def find_unsafe_enterprise_imports_in_directory(directory):
|
|
unsafe_imports = []
|
|
for root, _, files in os.walk(directory):
|
|
for file in files:
|
|
if file.endswith(".py"):
|
|
file_path = os.path.join(root, file)
|
|
imports = find_unsafe_enterprise_imports_in_file(file_path)
|
|
if imports:
|
|
unsafe_imports.extend(imports)
|
|
return unsafe_imports
|
|
|
|
|
|
if __name__ == "__main__":
|
|
# Check for unsafe enterprise imports in the litellm directory
|
|
directory_path = "./litellm"
|
|
unsafe_imports = find_unsafe_enterprise_imports_in_directory(directory_path)
|
|
|
|
if unsafe_imports:
|
|
print("🚨 UNSAFE ENTERPRISE IMPORTS FOUND (not in try-except blocks):")
|
|
for imp in unsafe_imports:
|
|
print(f"File: {imp['file']}")
|
|
print(f"Line: {imp['line']}")
|
|
print(f"Import: {imp['import']}")
|
|
print(f"Context: {imp['context']}")
|
|
print("---")
|
|
|
|
# Raise exception to fail CI/CD
|
|
raise Exception(
|
|
"🚨 Unsafe enterprise imports found. All enterprise imports must be wrapped in try-except blocks."
|
|
)
|
|
else:
|
|
print("✅ No unsafe enterprise imports found.")
|