# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Pre-commit check that forbids direct ``torch.cuda`` device/memory API calls.

Every file named on the command line is scanned unless its path starts with
one of the prefixes in ``ALLOWED_FILES``. Each line matching one of
``_TORCH_CUDA_PATTERNS`` is reported, and the script exits with status 1 if
any violation was found (0 otherwise).
See https://github.com/vllm-project/vllm/issues/30679.
"""

import sys

import regex as re

# --------------------------------------------------------------------------- #
# Regex: match `torch.cuda.xxx` but allow `torch.accelerator.xxx`
# --------------------------------------------------------------------------- #
_TORCH_CUDA_PATTERNS = [
    # The word boundary applies only after the identifier alternatives. A `\b`
    # placed right after `(` would require the next character to be a word
    # character, so calls like `torch.cuda.device("cuda:0")` would slip through.
    r"\btorch\.cuda\.(?:(?:empty_cache|synchronize|device_count|current_device"
    r"|memory_reserved|memory_allocated|max_memory_allocated"
    r"|max_memory_reserved|reset_peak_memory_stats|memory_stats|set_device)\b"
    r"|device\()",
    r"\bwith\s+torch\.cuda\.device\b",
    # Calls torch.cuda.{_is_compiled/_device_count_amdsmi/_device_count_nvml}
    # internally. There is deliberately no trailing `\b`: after `)` it would
    # only match when a word character follows, i.e. practically never.
    r"\bcuda_device_count_stateless\(\)",
]

# Compiled once at import time instead of per file.
_COMPILED_PATTERNS = [re.compile(p, re.MULTILINE) for p in _TORCH_CUDA_PATTERNS]

# Path prefixes (repo-relative) that are allowed to call torch.cuda directly.
ALLOWED_FILES = {"vllm/platforms/", "vllm/device_allocator/"}

_ERROR_MESSAGE = (
    "\033[91merror:\033[0m "  # red color
    "Found torch.cuda API call. Please refer RFC "
    "https://github.com/vllm-project/vllm/issues/30679, use "
    "torch.accelerator API instead."
)


def _find_violation_lines(content: str) -> list[int]:
    """Return the sorted, de-duplicated 1-based line numbers of forbidden calls.

    Several patterns can hit the same line (e.g. ``with torch.cuda.device(0):``
    matches both the ``with`` pattern and the ``device(`` pattern), so line
    numbers are collected in a set and each line is reported only once.
    """
    lines = set()
    for pattern in _COMPILED_PATTERNS:
        for match in pattern.finditer(content):
            # Number of newlines before the match start gives the 0-based line.
            lines.add(content.count("\n", 0, match.start()) + 1)
    return sorted(lines)


def scan_file(path: str) -> int:
    """Print every forbidden torch.cuda call in *path*.

    Args:
        path: File to scan; read as UTF-8 text.

    Returns:
        1 if at least one violation was found (all of them are printed, one
        per offending line, in line order), otherwise 0.
    """
    with open(path, encoding="utf-8") as f:
        content = f.read()

    violation_lines = _find_violation_lines(content)
    for line_num in violation_lines:
        print(f"{path}:{line_num}: {_ERROR_MESSAGE}")
    return 1 if violation_lines else 0


def main() -> int:
    """Scan every file in ``sys.argv[1:]``; return 1 if any file failed."""
    allowed_prefixes = tuple(ALLOWED_FILES)
    returncode = 0
    for filename in sys.argv[1:]:
        if filename.startswith(allowed_prefixes):
            continue
        returncode |= scan_file(filename)
    return returncode


if __name__ == "__main__":
    sys.exit(main())