[Hardware] Replace torch.cuda.device_count/current_device/set_device API (#36145)
Signed-off-by: Kunshang Ji <jikunshang95@gmail.com> Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
This commit is contained in:
@@ -8,8 +8,8 @@ import regex as re
|
||||
# Regex: match `torch.cuda.xxx` but allow `torch.accelerator.xxx`
|
||||
# --------------------------------------------------------------------------- #
|
||||
_TORCH_CUDA_PATTERNS = [
|
||||
r"\btorch\.cuda\.(empty_cache|synchronize|device\()\b",
|
||||
r"\bwith\btorch\.cuda\.device\b",
|
||||
r"\btorch\.cuda\.(empty_cache|synchronize|device_count|current_device|set_device|device\()\b",
|
||||
r"\bwith\storch\.cuda\.device\b",
|
||||
]
|
||||
|
||||
ALLOWED_FILES = {"vllm/platforms/", "vllm/device_allocator/"}
|
||||
@@ -25,7 +25,9 @@ def scan_file(path: str) -> int:
|
||||
print(
|
||||
f"{path}:{line_num}: "
|
||||
"\033[91merror:\033[0m " # red color
|
||||
"Found torch.cuda API call"
|
||||
"Found torch.cuda API call. Please refer RFC "
|
||||
"https://github.com/vllm-project/vllm/issues/30679, use "
|
||||
"torch.accelerator API instead."
|
||||
)
|
||||
return 1
|
||||
return 0
|
||||
|
||||
Reference in New Issue
Block a user