[Hardware] Replace torch.cuda.empty_cache with torch.accelerator.empty_cache (#30681)
Signed-off-by: Kunshang Ji <kunshang.ji@intel.com> Signed-off-by: Kunshang Ji <jikunshang95@gmail.com> Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
43
tools/pre_commit/check_torch_cuda.py
Normal file
43
tools/pre_commit/check_torch_cuda.py
Normal file
@@ -0,0 +1,43 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import sys
|
||||
|
||||
import regex as re
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# Regex: match `torch.cuda.xxx` but allow `torch.accelerator.xxx`
|
||||
# --------------------------------------------------------------------------- #
|
||||
_TORCH_CUDA_PATTERNS = [
|
||||
r"\btorch\.cuda\.empty_cache\b",
|
||||
]
|
||||
|
||||
ALLOWED_FILES = {"vllm/platforms/", "vllm/device_allocator/"}
|
||||
|
||||
|
||||
def scan_file(path: str) -> int:
|
||||
with open(path, encoding="utf-8") as f:
|
||||
content = f.read()
|
||||
for pattern in _TORCH_CUDA_PATTERNS:
|
||||
for match in re.finditer(pattern, content, re.MULTILINE):
|
||||
# Calculate line number from match position
|
||||
line_num = content[: match.start() + 1].count("\n") + 1
|
||||
print(
|
||||
f"{path}:{line_num}: "
|
||||
"\033[91merror:\033[0m " # red color
|
||||
"Found torch.cuda API call"
|
||||
)
|
||||
return 1
|
||||
return 0
|
||||
|
||||
|
||||
def main():
|
||||
returncode = 0
|
||||
for filename in sys.argv[1:]:
|
||||
if any(filename.startswith(prefix) for prefix in ALLOWED_FILES):
|
||||
continue
|
||||
returncode |= scan_file(filename)
|
||||
return returncode
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
Reference in New Issue
Block a user