Files
vllm/tools/pre_commit/check_torch_cuda.py
wliao2 4dfad17ed1 replace cuda_device_count_stateless() to current_platform.device_count() (#37841)
Signed-off-by: Liao, Wei <wei.liao@intel.com>
Signed-off-by: wliao2 <wei.liao@intel.com>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: Kunshang Ji <kunshang.ji@intel.com>
2026-03-31 22:32:54 +08:00

49 lines
1.8 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import sys
import regex as re
# --------------------------------------------------------------------------- #
# Regex: match `torch.cuda.xxx` but allow `torch.accelerator.xxx`
# --------------------------------------------------------------------------- #
# Each entry is a raw regex string searched over a whole file's contents.
#
# NOTE: a word boundary (\b) must never directly follow a non-word character
# such as "(" or ")": there it demands that the NEXT character be a word
# character, so ordinary calls (argument starting with a quote, or a call
# followed by a space/newline/comma) would silently slip through.
_TORCH_CUDA_PATTERNS = [
    # Selected device/memory helpers (bare names, bounded by \b so e.g.
    # `memory_stats_as_nested_dict` is not matched), plus any `device(` call.
    (
        r"\btorch\.cuda\.(?:(?:"
        r"empty_cache|synchronize|device_count|current_device|"
        r"memory_reserved|memory_allocated|max_memory_allocated|"
        r"max_memory_reserved|reset_peak_memory_stats|memory_stats|set_device"
        r")\b|device\()"
    ),
    # `with` statement form; \s+ tolerates any run of whitespace.
    r"\bwith\s+torch\.cuda\.device\b",
    # cuda_device_count_stateless calls
    # torch.cuda.{_is_compiled/_device_count_amdsmi/_device_count_nvml}
    # internally. No trailing \b after the closing parenthesis (see NOTE).
    r"\bcuda_device_count_stateless\(\)",
]
# Path prefixes exempt from the check. They are compared with str.startswith
# against each filename on the command line, so they presumably assume
# repo-relative, forward-slash paths as passed by pre-commit — confirm.
ALLOWED_FILES = {"vllm/platforms/", "vllm/device_allocator/"}
def scan_file(path: str) -> int:
    """Report every forbidden CUDA-specific call found in one file.

    Each regex in ``_TORCH_CUDA_PATTERNS`` is searched across the whole file.
    One error line (``path:line: error: ...``) is printed per offending
    source line, in ascending line order; a line matched by several patterns
    is reported only once.

    Args:
        path: Path of the file to scan, as given on the command line.

    Returns:
        1 if at least one forbidden call was found, otherwise 0.
    """
    with open(path, encoding="utf-8") as f:
        content = f.read()

    # Gather all distinct 1-based line numbers so a single run lists every
    # violation instead of stopping at the first one.
    offending_lines: set[int] = set()
    for pattern in _TORCH_CUDA_PATTERNS:
        for match in re.finditer(pattern, content, re.MULTILINE):
            # Newlines before the match start = 0-based line index.
            offending_lines.add(content.count("\n", 0, match.start()) + 1)

    for line_num in sorted(offending_lines):
        print(
            f"{path}:{line_num}: "
            "\033[91merror:\033[0m "  # red color
            "Found torch.cuda API call. Please refer RFC "
            "https://github.com/vllm-project/vllm/issues/30679, use "
            "torch.accelerator API instead."
        )
    return 1 if offending_lines else 0
def main() -> int:
    """Scan every file named on the command line.

    Files whose path begins with one of the ``ALLOWED_FILES`` prefixes are
    skipped. The result is the bitwise OR of the per-file scan results, so it
    is non-zero as soon as any scanned file reports a violation.
    """
    exempt_prefixes = tuple(ALLOWED_FILES)
    status = 0
    for path in sys.argv[1:]:
        if path.startswith(exempt_prefixes):
            continue
        status |= scan_file(path)
    return status
if __name__ == "__main__":
    # Use the aggregated scan status as the process exit code.
    raise SystemExit(main())