[Hardware] Replace torch.cuda.device_count/current_device/set_device API (#36145)

Signed-off-by: Kunshang Ji <jikunshang95@gmail.com>
Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
This commit is contained in:
Kunshang Ji
2026-03-12 22:57:47 +08:00
committed by GitHub
parent 2e693f48e7
commit 53ec16a705
89 changed files with 254 additions and 219 deletions

View File

@@ -134,7 +134,7 @@ class CoreEngineProcManager:
for proc, local_dp_rank in zip(self.processes, local_dp_ranks):
# Adjust device control in DP for non-CUDA platforms
# as well as external and ray launchers
# For CUDA platforms, we use torch.cuda.set_device()
# For CUDA platforms, we use torch.accelerator.set_device_index()()
if is_dp and (
not current_platform.is_cuda_alike()
or vllm_config.parallel_config.use_ray

View File

@@ -73,8 +73,8 @@ class SMControlContextManager:
assert current_platform.is_cuda(), (
"SM control is currently only supported on CUDA"
)
total_sms = num_compute_units(torch.cuda.current_device())
device = torch.accelerator.current_device_index()
total_sms = num_compute_units(device)
assert comm_sms < total_sms
self.total_sms = total_sms
@@ -204,7 +204,7 @@ class UBatchWrapper:
@torch.inference_mode()
def _capture_ubatch_thread(results, ubatch_metadata):
torch.cuda.set_device(self.device)
torch.accelerator.set_device_index(self.device)
ubatch_context = ubatch_metadata.context
with torch.cuda.stream(ubatch_context.compute_stream):
_ = torch.cuda.current_blas_handle()

View File

@@ -239,11 +239,11 @@ class Worker(WorkerBase):
# DP_LOCAL_RANK * TP_PP_WORLD_SIZE + TP_LOCAL_RANK
self.local_rank += dp_local_rank * tp_pp_world_size
assert self.local_rank < torch.cuda.device_count(), (
assert self.local_rank < torch.accelerator.device_count(), (
f"DP adjusted local rank {self.local_rank} is out of bounds. "
)
visible_device_count = (
torch.cuda.device_count() if torch.cuda.is_available() else 0
torch.accelerator.device_count() if torch.cuda.is_available() else 0
)
assert self.parallel_config.local_world_size <= visible_device_count, (
f"local_world_size ({self.parallel_config.local_world_size}) must "
@@ -252,7 +252,7 @@ class Worker(WorkerBase):
)
self.device = torch.device(f"cuda:{self.local_rank}")
current_platform.set_device(self.device)
torch.accelerator.set_device_index(self.device)
current_platform.check_if_supports_dtype(self.model_config.dtype)

View File

@@ -60,7 +60,7 @@ class XPUWorker(Worker):
and current_platform.is_xpu()
):
self.device = torch.device(f"xpu:{self.local_rank}")
current_platform.set_device(self.device)
torch.accelerator.set_device_index(self.device)
current_platform.check_if_supports_dtype(self.model_config.dtype)
torch.accelerator.empty_cache()
self.init_gpu_memory = torch.xpu.get_device_properties(