replace with torch.cuda.device with with torch.accelerator.device_index (#36144)

Signed-off-by: Yan Ma <yan.ma@intel.com>
This commit is contained in:
Yan Ma
2026-03-12 14:12:57 +08:00
committed by GitHub
parent 584a3f56de
commit 894843eb25
10 changed files with 17 additions and 15 deletions

View File

@@ -626,7 +626,11 @@ class BenchmarkWorker:
if visible_device != f"{self.device_id}":
need_device_guard = True
with torch.cuda.device(self.device_id) if need_device_guard else nullcontext():
with (
torch.accelerator.device_index(self.device_id)
if need_device_guard
else nullcontext()
):
for idx, config in enumerate(tqdm(search_space)):
try:
kernel_time = benchmark_config(