Consolidate and fix forbidden import pre-commit checks (#33982)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -121,24 +121,9 @@ repos:
|
|||||||
name: Update Dockerfile dependency graph
|
name: Update Dockerfile dependency graph
|
||||||
entry: tools/pre_commit/update-dockerfile-graph.sh
|
entry: tools/pre_commit/update-dockerfile-graph.sh
|
||||||
language: script
|
language: script
|
||||||
- id: enforce-import-regex-instead-of-re
|
- id: check-forbidden-imports
|
||||||
name: Enforce import regex as re
|
name: Check for forbidden imports
|
||||||
entry: python tools/pre_commit/enforce_regex_import.py
|
entry: python tools/pre_commit/check_forbidden_imports.py
|
||||||
language: python
|
|
||||||
types: [python]
|
|
||||||
pass_filenames: false
|
|
||||||
additional_dependencies: [regex]
|
|
||||||
# Forbid importing triton directly
|
|
||||||
- id: forbid-direct-triton-import
|
|
||||||
name: "Forbid direct 'import triton'"
|
|
||||||
entry: python tools/pre_commit/check_triton_import.py
|
|
||||||
language: python
|
|
||||||
types: [python]
|
|
||||||
pass_filenames: false
|
|
||||||
additional_dependencies: [regex]
|
|
||||||
- id: check-pickle-imports
|
|
||||||
name: Prevent new pickle/cloudpickle imports
|
|
||||||
entry: python tools/pre_commit/check_pickle_imports.py
|
|
||||||
language: python
|
language: python
|
||||||
types: [python]
|
types: [python]
|
||||||
additional_dependencies: [regex]
|
additional_dependencies: [regex]
|
||||||
|
|||||||
142
tools/pre_commit/check_forbidden_imports.py
Normal file
142
tools/pre_commit/check_forbidden_imports.py
Normal file
@@ -0,0 +1,142 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
import sys
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
import regex as re
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class ForbiddenImport:
    """One forbidden-import rule together with its exemptions."""

    # Regular expression source (a plain string, not a compiled pattern) that
    # matches the forbidden import statement.
    pattern: str
    # Hint appended to the error message printed for each violation.
    tip: str
    # A match whose text also matches this pattern is not reported.
    # ``(?!)`` is an empty negative lookahead, so the default can never match
    # (the previous default ``^$`` still matched the empty string).
    allowed_pattern: re.Pattern = re.compile(r"(?!)")
    # File paths that are exempt from this rule.
    allowed_files: set[str] = field(default_factory=set)
|
||||||
|
|
||||||
|
|
||||||
|
# Rules applied to every checked file, keyed by the name shown in the error
# message.
CHECK_IMPORTS = {
    "pickle/cloudpickle": ForbiddenImport(
        pattern=(
            r"^\s*(import\s+(pickle|cloudpickle)(\s|$|\sas)"
            r"|from\s+(pickle|cloudpickle)\s+import\b)"
        ),
        tip=(
            "Avoid using pickle or cloudpickle or add this file to "
            "tools/pre_commit/check_forbidden_imports.py."
        ),
        # STOP AND READ BEFORE YOU ADD ANYTHING ELSE TO THIS LIST:
        # pickle and cloudpickle are unsafe when deserializing data from
        # potentially untrusted parties and have resulted in multiple CVEs.
        # Prefer safer alternatives such as msgpack or pydantic, and only
        # extend this list after careful security review.
        allowed_files={
            # pickle
            "vllm/multimodal/hasher.py",
            "vllm/transformers_utils/config.py",
            "vllm/model_executor/models/registry.py",
            "vllm/compilation/caching.py",
            "vllm/compilation/piecewise_backend.py",
            "vllm/distributed/utils.py",
            "vllm/distributed/parallel_state.py",
            "vllm/distributed/device_communicators/all_reduce_utils.py",
            "vllm/distributed/device_communicators/shm_broadcast.py",
            "vllm/distributed/device_communicators/shm_object_storage.py",
            "vllm/utils/hashing.py",
            "tests/multimodal/media/test_base.py",
            "tests/tokenizers_/test_hf.py",
            "tests/utils_/test_hashing.py",
            "tests/compile/test_aot_compile.py",
            "benchmarks/kernels/graph_machete_bench.py",
            "benchmarks/kernels/benchmark_lora.py",
            "benchmarks/kernels/benchmark_machete.py",
            "benchmarks/fused_kernels/layernorm_rms_benchmarks.py",
            "benchmarks/cutlass_benchmarks/w8a8_benchmarks.py",
            "benchmarks/cutlass_benchmarks/sparse_benchmarks.py",
            # cloudpickle
            "vllm/v1/executor/multiproc_executor.py",
            "vllm/v1/executor/ray_executor.py",
            "vllm/entrypoints/llm.py",
            "tests/utils.py",
            # pickle and cloudpickle
            "vllm/v1/serial_utils.py",
        },
    ),
    "re": ForbiddenImport(
        pattern=r"^\s*(?:import\s+re(?:$|\s|,)|from\s+re\s+import)",
        tip="Replace 'import re' with 'import regex as re' or 'import regex'.",
        allowed_pattern=re.compile(r"^\s*import\s+regex(\s*|\s+as\s+re\s*)$"),
        allowed_files={"setup.py"},
    ),
    "triton": ForbiddenImport(
        # NOTE(review): unlike the other rules this pattern has no leading
        # ``\s*``, so indented ``import triton`` statements are not matched.
        # Confirm whether that is intended before tightening it.
        pattern=r"^(from|import)\s+triton(\s|\.|$)",
        tip="Use 'from vllm.triton_utils import triton' instead.",
        # Raw string with the dot escaped so it only matches a literal ".".
        allowed_pattern=re.compile(
            r"from vllm\.triton_utils import (triton|tl|tl, triton)"
        ),
        allowed_files={"vllm/triton_utils/importing.py"},
    ),
}
|
||||||
|
|
||||||
|
|
||||||
|
def check_file(
    path: str, checks: "dict[str, ForbiddenImport] | None" = None
) -> int:
    """Scan one file for forbidden imports and print every occurrence.

    Args:
        path: File to scan. It is read as UTF-8 and also compared verbatim
            against each rule's ``allowed_files``.
        checks: Mapping of rule name to rule. Defaults to ``CHECK_IMPORTS``.

    Returns:
        1 if at least one forbidden import was reported, otherwise 0.
    """
    if checks is None:
        checks = CHECK_IMPORTS

    with open(path, encoding="utf-8") as f:
        content = f.read()

    return_code = 0
    # Each rule is matched against the whole file content at once.
    for import_name, forbidden_import in checks.items():
        # Skip files that are allowed for this import
        if path in forbidden_import.allowed_files:
            continue
        for match in re.finditer(forbidden_import.pattern, content, re.MULTILINE):
            matched_text = match.group()
            if forbidden_import.allowed_pattern.match(matched_text):
                continue
            # A leading ``\s*`` in the pattern can swallow any number of blank
            # lines before the statement, so count newlines up to the first
            # non-whitespace character rather than up to the match start.
            leading = len(matched_text) - len(matched_text.lstrip())
            statement_start = match.start() + leading
            line_num = content.count("\n", 0, statement_start) + 1
            print(
                f"{path}:{line_num}: "
                "\033[91merror:\033[0m "  # red color
                f"Found forbidden import: {import_name}. {forbidden_import.tip}"
            )
            return_code = 1
    return return_code
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Check every path given on the command line.

    Returns the bitwise OR of the per-file results, so the result is non-zero
    as soon as any file fails.
    """
    # Collect all results first: every file is scanned and reported even
    # when an earlier one already failed.
    statuses = [check_file(candidate) for candidate in sys.argv[1:]]
    exit_code = 0
    for status in statuses:
        exit_code |= status
    return exit_code
|
||||||
|
|
||||||
|
|
||||||
|
def test_regex():
    """Self-test for the pickle/cloudpickle rule (run with ``--test-regex``).

    Raises:
        AssertionError: If any sample line is classified incorrectly.
    """
    # The rule stores its pattern as a plain string, which has no ``.match``
    # method, so matching has to go through ``re.match``.
    pickle_pattern = CHECK_IMPORTS["pickle/cloudpickle"].pattern
    test_cases = [
        # Should match
        ("import pickle", True),
        ("import cloudpickle", True),
        ("import pickle as pkl", True),
        ("import cloudpickle as cpkl", True),
        ("from pickle import *", True),
        ("from cloudpickle import dumps", True),
        ("from pickle import dumps, loads", True),
        ("from cloudpickle import (dumps, loads)", True),
        (" import pickle", True),
        ("\timport cloudpickle", True),
        ("from pickle import loads", True),
        # Should not match
        ("import somethingelse", False),
        ("from somethingelse import pickle", False),
        ("# import pickle", False),
        ("print('import pickle')", False),
        ("import pickleas as asdf", False),
    ]
    for i, (line, should_match) in enumerate(test_cases):
        result = bool(re.match(pickle_pattern, line))
        assert result == should_match, (
            f"Test case {i} failed: '{line}' (expected {should_match}, got {result})"
        )
    print("All regex tests passed.")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Without ``--test-regex`` run the checker and exit with its status;
    # with it, only run the regex self-test.
    if "--test-regex" not in sys.argv:
        sys.exit(main())
    test_regex()
|
||||||
@@ -1,111 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# SPDX-License-Identifier: Apache-2.0
|
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
||||||
import sys
|
|
||||||
|
|
||||||
import regex as re
|
|
||||||
|
|
||||||
# List of files (relative to repo root) that are allowed to import pickle or
|
|
||||||
# cloudpickle
|
|
||||||
#
|
|
||||||
# STOP AND READ BEFORE YOU ADD ANYTHING ELSE TO THIS LIST:
|
|
||||||
# The pickle and cloudpickle modules are known to be unsafe when deserializing
|
|
||||||
# data from potentially untrusted parties. They have resulted in multiple CVEs
|
|
||||||
# for vLLM and numerous vulnerabilities in the Python ecosystem more broadly.
|
|
||||||
# Before adding new uses of pickle/cloudpickle, please consider safer
|
|
||||||
# alternatives like msgpack or pydantic that are already in use in vLLM. Only
|
|
||||||
# add to this list if absolutely necessary and after careful security review.
|
|
||||||
ALLOWED_FILES = {
    # --- files importing pickle ---
    "vllm/multimodal/hasher.py",
    "vllm/transformers_utils/config.py",
    "vllm/model_executor/models/registry.py",
    "vllm/compilation/caching.py",
    "vllm/compilation/piecewise_backend.py",
    "vllm/distributed/utils.py",
    "vllm/distributed/parallel_state.py",
    "vllm/distributed/device_communicators/all_reduce_utils.py",
    "vllm/distributed/device_communicators/shm_broadcast.py",
    "vllm/distributed/device_communicators/shm_object_storage.py",
    "vllm/utils/hashing.py",
    "tests/multimodal/media/test_base.py",
    "tests/tokenizers_/test_hf.py",
    "tests/utils_/test_hashing.py",
    "tests/compile/test_aot_compile.py",
    "benchmarks/kernels/graph_machete_bench.py",
    "benchmarks/kernels/benchmark_lora.py",
    "benchmarks/kernels/benchmark_machete.py",
    "benchmarks/fused_kernels/layernorm_rms_benchmarks.py",
    "benchmarks/cutlass_benchmarks/w8a8_benchmarks.py",
    "benchmarks/cutlass_benchmarks/sparse_benchmarks.py",
    # --- files importing cloudpickle ---
    "vllm/v1/executor/multiproc_executor.py",
    "vllm/v1/executor/ray_executor.py",
    "vllm/entrypoints/llm.py",
    "tests/utils.py",
    # --- files importing both pickle and cloudpickle ---
    "vllm/v1/serial_utils.py",
}
|
|
||||||
|
|
||||||
# Matches a (possibly indented) ``import pickle`` / ``import cloudpickle``
# statement, optionally aliased, or a ``from pickle|cloudpickle import ...``.
PICKLE_RE = re.compile(
    r"^\s*(import\s+(pickle|cloudpickle)(\s|$|\sas)"
    r"|from\s+(pickle|cloudpickle)\s+import\b)"
)
|
|
||||||
|
|
||||||
|
|
||||||
def scan_file(path: str) -> int:
    """Report the first pickle/cloudpickle import found in *path*.

    Returns 1 after printing the first offending line, or 0 when no line
    matches ``PICKLE_RE``.
    """
    with open(path, encoding="utf-8") as handle:
        # Stop at the first hit; later imports in the same file are not listed.
        offending = next(
            (
                lineno
                for lineno, text in enumerate(handle, 1)
                if PICKLE_RE.match(text)
            ),
            None,
        )
    if offending is None:
        return 0
    print(
        f"{path}:{offending}: "
        "\033[91merror:\033[0m "  # red color
        "Found pickle/cloudpickle import"
    )
    return 1
|
|
||||||
|
|
||||||
|
|
||||||
def main():
    """Scan every command-line path that is not in ``ALLOWED_FILES``.

    Returns the bitwise OR of the per-file ``scan_file`` results.
    """
    to_scan = [name for name in sys.argv[1:] if name not in ALLOWED_FILES]
    status = 0
    for name in to_scan:
        status |= scan_file(name)
    return status
|
|
||||||
|
|
||||||
|
|
||||||
def test_regex():
    """Self-test for ``PICKLE_RE``; raises AssertionError on a wrong result."""
    matching = [
        "import pickle",
        "import cloudpickle",
        "import pickle as pkl",
        "import cloudpickle as cpkl",
        "from pickle import *",
        "from cloudpickle import dumps",
        "from pickle import dumps, loads",
        "from cloudpickle import (dumps, loads)",
        " import pickle",
        "\timport cloudpickle",
        "from pickle import loads",
    ]
    non_matching = [
        "import somethingelse",
        "from somethingelse import pickle",
        "# import pickle",
        "print('import pickle')",
        "import pickleas as asdf",
    ]
    # Matching samples come first so the indices in failure messages keep
    # their original numbering.
    test_cases = [(sample, True) for sample in matching]
    test_cases += [(sample, False) for sample in non_matching]
    for i, (line, should_match) in enumerate(test_cases):
        result = PICKLE_RE.match(line) is not None
        assert result == should_match, (
            f"Test case {i} failed: '{line}' (expected {should_match}, got {result})"
        )
    print("All regex tests passed.")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
    # ``--test-regex`` runs only the self-test; otherwise the process exits
    # with the status returned by ``main``.
    if "--test-regex" not in sys.argv:
        sys.exit(main())
    test_regex()
|
|
||||||
@@ -1,88 +0,0 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
||||||
import subprocess
|
|
||||||
import sys
|
|
||||||
|
|
||||||
import regex as re
|
|
||||||
|
|
||||||
# Matches ``import triton...`` / ``from triton... import`` at the start of a
# line that has already been stripped.
FORBIDDEN_IMPORT_RE = re.compile(r"^(from|import)\s+triton(\s|\.|$)")

# The only import statements through which triton may be brought in.
ALLOWED_LINES = {
    "from vllm.triton_utils import triton",
    "from vllm.triton_utils import tl",
    "from vllm.triton_utils import tl, triton",
}

# Files that are exempt from the check altogether.
ALLOWED_FILES = {"vllm/triton_utils/importing.py"}
|
|
||||||
|
|
||||||
|
|
||||||
def is_allowed_file(current_file: str) -> bool:
    """Return True when *current_file* is listed in ``ALLOWED_FILES``."""
    exempt = current_file in ALLOWED_FILES
    return exempt
|
|
||||||
|
|
||||||
|
|
||||||
def is_forbidden_import(line: str) -> bool:
    """Return True when the stripped *line* is a disallowed triton import."""
    candidate = line.strip()
    # Explicitly allowed statements are never violations.
    if candidate in ALLOWED_LINES:
        return False
    return FORBIDDEN_IMPORT_RE.match(candidate) is not None
|
|
||||||
|
|
||||||
|
|
||||||
def parse_diff(diff: str) -> list[str]:
    """Collect forbidden triton imports among the added lines of a diff.

    Args:
        diff: Unified diff text.

    Returns:
        One ``"<file>:<line>: <statement>"`` entry per violation.
    """
    found = []
    path = None
    lineno = None
    in_exempt_file = False

    for raw in diff.splitlines():
        # New-file header: remember the path and whether it is exempt.
        if raw.startswith("+++ b/"):
            path = raw[6:]
            in_exempt_file = is_allowed_file(path)
            continue
        if in_exempt_file:
            continue
        # Hunk header: the number after "+" is the first new-side line.
        if raw.startswith("@@"):
            hunk = re.search(r"\+(\d+)", raw)
            if hunk:
                lineno = int(hunk.group(1)) - 1  # next "+ line" is here
            continue
        # Only added lines are inspected; lines starting with "++" are skipped.
        if not raw.startswith("+") or raw.startswith("++"):
            continue
        lineno += 1
        added = raw[1:]
        if is_forbidden_import(added):
            found.append(f"{path}:{lineno}: {added.strip()}")
    return found
|
|
||||||
|
|
||||||
|
|
||||||
def get_diff(diff_type: str) -> str:
    """Return the zero-context git diff for the requested change set.

    Args:
        diff_type: ``"staged"`` or ``"unstaged"``.

    Raises:
        ValueError: If *diff_type* is neither of the two supported values.
    """
    commands = {
        "staged": ["git", "diff", "--cached", "--unified=0"],
        "unstaged": ["git", "diff", "--unified=0"],
    }
    if diff_type not in commands:
        raise ValueError(f"Unknown diff_type: {diff_type}")
    return subprocess.check_output(commands[diff_type], text=True)
|
|
||||||
|
|
||||||
|
|
||||||
def main():
    """Check staged and unstaged diffs for direct triton imports.

    Returns 1 (after printing every violation) if any was found, else 0.
    """
    collected = []
    for kind in ("staged", "unstaged"):
        try:
            collected += parse_diff(get_diff(kind))
        except subprocess.CalledProcessError as e:
            # A failing git call is reported but does not abort the check.
            print(f"[{kind}] Git diff failed: {e}", file=sys.stderr)

    if not collected:
        return 0

    print(
        "❌ Forbidden direct `import triton` detected."
        " ➤ Use `from vllm.triton_utils import triton` instead.\n"
    )
    for entry in collected:
        print(f"❌ {entry}")
    return 1
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
    # Exit with the status returned by ``main``.
    raise SystemExit(main())
|
|
||||||
@@ -1,83 +0,0 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
||||||
import subprocess
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import regex as re
|
|
||||||
|
|
||||||
# ``import re`` (alone or in a comma-separated list) and ``from re import``.
FORBIDDEN_PATTERNS = re.compile(r"^\s*(?:import\s+re(?:$|\s|,)|from\s+re\s+import)")

# Import statements that are explicitly accepted.
ALLOWED_PATTERNS = [
    re.compile(source)
    for source in (
        r"^\s*import\s+regex\s+as\s+re\s*$",
        r"^\s*import\s+regex\s*$",
    )
]
|
|
||||||
|
|
||||||
|
|
||||||
def get_staged_python_files() -> list[str]:
    """List staged added/modified files whose names end in ``.py``.

    Returns an empty list when nothing is staged or the git command fails.
    """
    try:
        completed = subprocess.run(
            ["git", "diff", "--cached", "--name-only", "--diff-filter=AM"],
            capture_output=True,
            text=True,
            check=True,
        )
    except subprocess.CalledProcessError:
        return []

    listing = completed.stdout.strip()
    if not listing:
        return []
    return [name for name in listing.split("\n") if name.endswith(".py")]
|
|
||||||
|
|
||||||
|
|
||||||
def is_forbidden_import(line: str) -> bool:
    """Return True when *line* imports ``re`` in a way that is not allowed."""
    candidate = line.strip()
    if FORBIDDEN_PATTERNS.match(candidate) is None:
        return False
    # A forbidden-looking line is still fine if an allowed pattern matches it.
    for allowed in ALLOWED_PATTERNS:
        if allowed.match(candidate):
            return False
    return True
|
|
||||||
|
|
||||||
|
|
||||||
def check_file(filepath: str) -> list[tuple[int, str]]:
    """Collect the forbidden ``re`` imports found in one file.

    Args:
        filepath: File to read as UTF-8.

    Returns:
        ``(line_number, stripped_line)`` pairs, one per violation. If the file
        cannot be read or decoded, the violations found so far are returned.
    """
    import sys  # local import: only needed for the warning below

    violations = []
    try:
        with open(filepath, encoding="utf-8") as f:
            for line_num, line in enumerate(f, 1):
                if is_forbidden_import(line):
                    violations.append((line_num, line.strip()))
    except (OSError, UnicodeDecodeError) as exc:
        # Stay best-effort, but say so instead of silently skipping the file.
        print(f"warning: could not check {filepath}: {exc}", file=sys.stderr)
    return violations
|
|
||||||
|
|
||||||
|
|
||||||
def main() -> int:
    """Check staged Python files for forbidden ``re`` imports.

    Returns 1 (after printing the violations and a fix hint) if any violation
    was found, otherwise 0.
    """
    staged = get_staged_python_files()
    if not staged:
        return 0

    total_violations = 0
    for filepath in staged:
        # Skip files that no longer exist on disk, and setup.py.
        if not Path(filepath).exists() or filepath == "setup.py":
            continue

        violations = check_file(filepath)
        if not violations:
            continue
        print(f"\n❌ {filepath}:")
        for line_num, line in violations:
            print(f" Line {line_num}: {line}")
        total_violations += len(violations)

    if total_violations <= 0:
        return 0

    print(f"\n💡 Found {total_violations} violation(s).")
    print("❌ Please replace 'import re' with 'import regex as re'")
    print(" Also replace 'from re import ...' with 'from regex import ...'")  # noqa: E501
    print("✅ Allowed imports:")
    print(" - import regex as re")
    print(" - import regex")  # noqa: E501
    return 1
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
    # Exit with the status returned by ``main``.
    exit_status = main()
    raise SystemExit(exit_status)
|
|
||||||
Reference in New Issue
Block a user