@dataclass
class ForbiddenImport:
    """A single forbidden-import rule enforced by this pre-commit hook."""

    # Regex source identifying an offending import statement. It is searched
    # over the whole file with re.MULTILINE, so "^" anchors at line starts.
    pattern: str
    # Hint appended to the error message printed for each violation.
    tip: str
    # A forbidden match whose text also matches this pattern is not reported.
    # The default "(?!)" is an empty negative lookahead, which never matches.
    allowed_pattern: re.Pattern = re.compile(r"(?!)")
    # Paths (compared verbatim with the path handed to check_file) that are
    # exempt from this rule.
    allowed_files: set[str] = field(default_factory=set)


CHECK_IMPORTS = {
    "pickle/cloudpickle": ForbiddenImport(
        pattern=(
            r"^\s*(import\s+(pickle|cloudpickle)(\s|$|\sas)"
            r"|from\s+(pickle|cloudpickle)\s+import\b)"
        ),
        tip=(
            "Avoid using pickle or cloudpickle or add this file to "
            "tools/pre_commit/check_forbidden_imports.py."
        ),
        # List of files (relative to repo root) that are allowed to import
        # pickle or cloudpickle.
        #
        # STOP AND READ BEFORE YOU ADD ANYTHING ELSE TO THIS LIST:
        # The pickle and cloudpickle modules are known to be unsafe when
        # deserializing data from potentially untrusted parties. They have
        # resulted in multiple CVEs for vLLM and numerous vulnerabilities in
        # the Python ecosystem more broadly. Before adding new uses of
        # pickle/cloudpickle, please consider safer alternatives like msgpack
        # or pydantic that are already in use in vLLM. Only add to this list
        # if absolutely necessary and after careful security review.
        allowed_files={
            # pickle
            "vllm/multimodal/hasher.py",
            "vllm/transformers_utils/config.py",
            "vllm/model_executor/models/registry.py",
            "vllm/compilation/caching.py",
            "vllm/compilation/piecewise_backend.py",
            "vllm/distributed/utils.py",
            "vllm/distributed/parallel_state.py",
            "vllm/distributed/device_communicators/all_reduce_utils.py",
            "vllm/distributed/device_communicators/shm_broadcast.py",
            "vllm/distributed/device_communicators/shm_object_storage.py",
            "vllm/utils/hashing.py",
            "tests/multimodal/media/test_base.py",
            "tests/tokenizers_/test_hf.py",
            "tests/utils_/test_hashing.py",
            "tests/compile/test_aot_compile.py",
            "benchmarks/kernels/graph_machete_bench.py",
            "benchmarks/kernels/benchmark_lora.py",
            "benchmarks/kernels/benchmark_machete.py",
            "benchmarks/fused_kernels/layernorm_rms_benchmarks.py",
            "benchmarks/cutlass_benchmarks/w8a8_benchmarks.py",
            "benchmarks/cutlass_benchmarks/sparse_benchmarks.py",
            # cloudpickle
            "vllm/v1/executor/multiproc_executor.py",
            "vllm/v1/executor/ray_executor.py",
            "vllm/entrypoints/llm.py",
            "tests/utils.py",
            # pickle and cloudpickle
            "vllm/v1/serial_utils.py",
        },
    ),
    "re": ForbiddenImport(
        pattern=r"^\s*(?:import\s+re(?:$|\s|,)|from\s+re\s+import)",
        tip="Replace 'import re' with 'import regex as re' or 'import regex'.",
        allowed_pattern=re.compile(r"^\s*import\s+regex(\s*|\s+as\s+re\s*)$"),
        allowed_files={"setup.py"},
    ),
    # NOTE(review): unlike the two patterns above, this one has no leading
    # "\s*", so an indented "import triton" is not flagged -- confirm that
    # this is intended.
    "triton": ForbiddenImport(
        pattern=r"^(from|import)\s+triton(\s|\.|$)",
        tip="Use 'from vllm.triton_utils import triton' instead.",
        allowed_pattern=re.compile(
            "from vllm.triton_utils import (triton|tl|tl, triton)"
        ),
        allowed_files={"vllm/triton_utils/importing.py"},
    ),
}


def _match_line_number(content: str, match: "re.Match") -> int:
    """Return the 1-based line on which the import statement of *match* starts.

    Patterns that begin with ``^\\s*`` can absorb blank lines preceding the
    statement, so ``match.start()`` may point several lines too early. The
    leading whitespace of the matched text is skipped before newlines are
    counted.
    """
    text = match.group()
    statement_offset = match.start() + len(text) - len(text.lstrip())
    return content.count("\n", 0, statement_offset) + 1


def check_file(path: str) -> int:
    """Scan one file for every rule in CHECK_IMPORTS and print violations.

    Args:
        path: File to read as UTF-8. It is also compared verbatim against
            each rule's ``allowed_files``.

    Returns:
        1 if at least one forbidden import was reported, otherwise 0.
    """
    with open(path, encoding="utf-8") as f:
        content = f.read()
    return_code = 0
    # Check all patterns in the whole file
    for import_name, forbidden_import in CHECK_IMPORTS.items():
        # Skip files that are allowed for this import
        if path in forbidden_import.allowed_files:
            continue
        # Search for forbidden imports
        for match in re.finditer(forbidden_import.pattern, content, re.MULTILINE):
            # Check if it's allowed
            if forbidden_import.allowed_pattern.match(match.group()):
                continue
            line_num = _match_line_number(content, match)
            print(
                f"{path}:{line_num}: "
                "\033[91merror:\033[0m "  # red color
                f"Found forbidden import: {import_name}. {forbidden_import.tip}"
            )
            return_code = 1
    return return_code


def main() -> int:
    """Check every path given on the command line; return 1 on any violation."""
    returncode = 0
    for path in sys.argv[1:]:
        returncode |= check_file(path)
    return returncode


def test_regex() -> None:
    """Self-test of the forbidden-import patterns (run with ``--test-regex``).

    Each sample is a single line matched from its start, without
    re.MULTILINE. Raises AssertionError on the first mismatch.
    """
    test_cases = {
        "pickle/cloudpickle": [
            # Should match
            ("import pickle", True),
            ("import cloudpickle", True),
            ("import pickle as pkl", True),
            ("import cloudpickle as cpkl", True),
            ("from pickle import *", True),
            ("from cloudpickle import dumps", True),
            ("from pickle import dumps, loads", True),
            ("from cloudpickle import (dumps, loads)", True),
            (" import pickle", True),
            ("\timport cloudpickle", True),
            ("from pickle import loads", True),
            # Should not match
            ("import somethingelse", False),
            ("from somethingelse import pickle", False),
            ("# import pickle", False),
            ("print('import pickle')", False),
            ("import pickleas as asdf", False),
        ],
        "re": [
            # Should match
            ("import re", True),
            ("import re as r", True),
            ("import re, sys", True),
            ("from re import compile", True),
            # Should not match
            ("import regex as re", False),
            ("import regex", False),
            ("import requests", False),
        ],
        "triton": [
            # Should match
            ("import triton", True),
            ("import triton.language as tl", True),
            ("from triton import jit", True),
            # Should not match
            ("from vllm.triton_utils import triton", False),
            ("import triton_kernels", False),
        ],
    }
    for import_name, cases in test_cases.items():
        # ``pattern`` is stored as a plain string, so go through re.match
        # rather than calling .match() on it.
        pattern = CHECK_IMPORTS[import_name].pattern
        for i, (line, should_match) in enumerate(cases):
            result = bool(re.match(pattern, line))
            assert result == should_match, (
                f"Test case {i} failed: '{line}' "
                f"(expected {should_match}, got {result})"
            )
    print("All regex tests passed.")


if __name__ == "__main__":
    if "--test-regex" in sys.argv:
        test_regex()
    else:
        sys.exit(main())
/dev/null @@ -1,111 +0,0 @@ -#!/usr/bin/env python3 -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import sys - -import regex as re - -# List of files (relative to repo root) that are allowed to import pickle or -# cloudpickle -# -# STOP AND READ BEFORE YOU ADD ANYTHING ELSE TO THIS LIST: -# The pickle and cloudpickle modules are known to be unsafe when deserializing -# data from potentially untrusted parties. They have resulted in multiple CVEs -# for vLLM and numerous vulnerabilities in the Python ecosystem more broadly. -# Before adding new uses of pickle/cloudpickle, please consider safer -# alternatives like msgpack or pydantic that are already in use in vLLM. Only -# add to this list if absolutely necessary and after careful security review. -ALLOWED_FILES = { - # pickle - "vllm/multimodal/hasher.py", - "vllm/transformers_utils/config.py", - "vllm/model_executor/models/registry.py", - "vllm/compilation/caching.py", - "vllm/compilation/piecewise_backend.py", - "vllm/distributed/utils.py", - "vllm/distributed/parallel_state.py", - "vllm/distributed/device_communicators/all_reduce_utils.py", - "vllm/distributed/device_communicators/shm_broadcast.py", - "vllm/distributed/device_communicators/shm_object_storage.py", - "vllm/utils/hashing.py", - "tests/multimodal/media/test_base.py", - "tests/tokenizers_/test_hf.py", - "tests/utils_/test_hashing.py", - "tests/compile/test_aot_compile.py", - "benchmarks/kernels/graph_machete_bench.py", - "benchmarks/kernels/benchmark_lora.py", - "benchmarks/kernels/benchmark_machete.py", - "benchmarks/fused_kernels/layernorm_rms_benchmarks.py", - "benchmarks/cutlass_benchmarks/w8a8_benchmarks.py", - "benchmarks/cutlass_benchmarks/sparse_benchmarks.py", - # cloudpickle - "vllm/v1/executor/multiproc_executor.py", - "vllm/v1/executor/ray_executor.py", - "vllm/entrypoints/llm.py", - "tests/utils.py", - # pickle and cloudpickle - "vllm/v1/serial_utils.py", -} - -PICKLE_RE = 
re.compile( - r"^\s*(import\s+(pickle|cloudpickle)(\s|$|\sas)" - r"|from\s+(pickle|cloudpickle)\s+import\b)" -) - - -def scan_file(path: str) -> int: - with open(path, encoding="utf-8") as f: - for i, line in enumerate(f, 1): - if PICKLE_RE.match(line): - print( - f"{path}:{i}: " - "\033[91merror:\033[0m " # red color - "Found pickle/cloudpickle import" - ) - return 1 - return 0 - - -def main(): - returncode = 0 - for filename in sys.argv[1:]: - if filename in ALLOWED_FILES: - continue - returncode |= scan_file(filename) - return returncode - - -def test_regex(): - test_cases = [ - # Should match - ("import pickle", True), - ("import cloudpickle", True), - ("import pickle as pkl", True), - ("import cloudpickle as cpkl", True), - ("from pickle import *", True), - ("from cloudpickle import dumps", True), - ("from pickle import dumps, loads", True), - ("from cloudpickle import (dumps, loads)", True), - (" import pickle", True), - ("\timport cloudpickle", True), - ("from pickle import loads", True), - # Should not match - ("import somethingelse", False), - ("from somethingelse import pickle", False), - ("# import pickle", False), - ("print('import pickle')", False), - ("import pickleas as asdf", False), - ] - for i, (line, should_match) in enumerate(test_cases): - result = bool(PICKLE_RE.match(line)) - assert result == should_match, ( - f"Test case {i} failed: '{line}' (expected {should_match}, got {result})" - ) - print("All regex tests passed.") - - -if __name__ == "__main__": - if "--test-regex" in sys.argv: - test_regex() - else: - sys.exit(main()) diff --git a/tools/pre_commit/check_triton_import.py b/tools/pre_commit/check_triton_import.py deleted file mode 100644 index 1b83074fe..000000000 --- a/tools/pre_commit/check_triton_import.py +++ /dev/null @@ -1,88 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import subprocess -import sys - -import regex as re - -FORBIDDEN_IMPORT_RE = 
re.compile(r"^(from|import)\s+triton(\s|\.|$)") - -# the way allowed to import triton -ALLOWED_LINES = { - "from vllm.triton_utils import triton", - "from vllm.triton_utils import tl", - "from vllm.triton_utils import tl, triton", -} - -ALLOWED_FILES = {"vllm/triton_utils/importing.py"} - - -def is_allowed_file(current_file: str) -> bool: - return current_file in ALLOWED_FILES - - -def is_forbidden_import(line: str) -> bool: - stripped = line.strip() - return bool(FORBIDDEN_IMPORT_RE.match(stripped)) and stripped not in ALLOWED_LINES - - -def parse_diff(diff: str) -> list[str]: - violations = [] - current_file = None - current_lineno = None - skip_allowed_file = False - - for line in diff.splitlines(): - if line.startswith("+++ b/"): - current_file = line[6:] - skip_allowed_file = is_allowed_file(current_file) - elif skip_allowed_file: - continue - elif line.startswith("@@"): - match = re.search(r"\+(\d+)", line) - if match: - current_lineno = int(match.group(1)) - 1 # next "+ line" is here - elif line.startswith("+") and not line.startswith("++"): - current_lineno += 1 - code_line = line[1:] - if is_forbidden_import(code_line): - violations.append( - f"{current_file}:{current_lineno}: {code_line.strip()}" - ) - return violations - - -def get_diff(diff_type: str) -> str: - if diff_type == "staged": - return subprocess.check_output( - ["git", "diff", "--cached", "--unified=0"], text=True - ) - elif diff_type == "unstaged": - return subprocess.check_output(["git", "diff", "--unified=0"], text=True) - else: - raise ValueError(f"Unknown diff_type: {diff_type}") - - -def main(): - all_violations = [] - for diff_type in ["staged", "unstaged"]: - try: - diff_output = get_diff(diff_type) - violations = parse_diff(diff_output) - all_violations.extend(violations) - except subprocess.CalledProcessError as e: - print(f"[{diff_type}] Git diff failed: {e}", file=sys.stderr) - - if all_violations: - print( - "āŒ Forbidden direct `import triton` detected." 
- " āž¤ Use `from vllm.triton_utils import triton` instead.\n" - ) - for v in all_violations: - print(f"āŒ {v}") - return 1 - return 0 - - -if __name__ == "__main__": - sys.exit(main()) diff --git a/tools/pre_commit/enforce_regex_import.py b/tools/pre_commit/enforce_regex_import.py deleted file mode 100644 index a29952e92..000000000 --- a/tools/pre_commit/enforce_regex_import.py +++ /dev/null @@ -1,83 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import subprocess -from pathlib import Path - -import regex as re - -FORBIDDEN_PATTERNS = re.compile(r"^\s*(?:import\s+re(?:$|\s|,)|from\s+re\s+import)") -ALLOWED_PATTERNS = [ - re.compile(r"^\s*import\s+regex\s+as\s+re\s*$"), - re.compile(r"^\s*import\s+regex\s*$"), -] - - -def get_staged_python_files() -> list[str]: - try: - result = subprocess.run( - ["git", "diff", "--cached", "--name-only", "--diff-filter=AM"], - capture_output=True, - text=True, - check=True, - ) - files = result.stdout.strip().split("\n") if result.stdout.strip() else [] - return [f for f in files if f.endswith(".py")] - except subprocess.CalledProcessError: - return [] - - -def is_forbidden_import(line: str) -> bool: - line = line.strip() - return bool( - FORBIDDEN_PATTERNS.match(line) - and not any(pattern.match(line) for pattern in ALLOWED_PATTERNS) - ) - - -def check_file(filepath: str) -> list[tuple[int, str]]: - violations = [] - try: - with open(filepath, encoding="utf-8") as f: - for line_num, line in enumerate(f, 1): - if is_forbidden_import(line): - violations.append((line_num, line.strip())) - except (OSError, UnicodeDecodeError): - pass - return violations - - -def main() -> int: - files = get_staged_python_files() - if not files: - return 0 - - total_violations = 0 - - for filepath in files: - if not Path(filepath).exists(): - continue - - if filepath == "setup.py": - continue - - violations = check_file(filepath) - if violations: - print(f"\nāŒ {filepath}:") - 
for line_num, line in violations: - print(f" Line {line_num}: {line}") - total_violations += 1 - - if total_violations > 0: - print(f"\nšŸ’” Found {total_violations} violation(s).") - print("āŒ Please replace 'import re' with 'import regex as re'") - print(" Also replace 'from re import ...' with 'from regex import ...'") # noqa: E501 - print("āœ… Allowed imports:") - print(" - import regex as re") - print(" - import regex") # noqa: E501 - return 1 - - return 0 - - -if __name__ == "__main__": - raise SystemExit(main())