replace with torch.cuda.device with with torch.accelerator.device_index (#36144)
Signed-off-by: Yan Ma <yan.ma@intel.com>
This commit is contained in:
@@ -626,7 +626,11 @@ class BenchmarkWorker:
|
||||
if visible_device != f"{self.device_id}":
|
||||
need_device_guard = True
|
||||
|
||||
with torch.cuda.device(self.device_id) if need_device_guard else nullcontext():
|
||||
with (
|
||||
torch.accelerator.device_index(self.device_id)
|
||||
if need_device_guard
|
||||
else nullcontext()
|
||||
):
|
||||
for idx, config in enumerate(tqdm(search_space)):
|
||||
try:
|
||||
kernel_time = benchmark_config(
|
||||
|
||||
Reference in New Issue
Block a user