[Hardware] Replace torch.cuda.device_count/current_device/set_device API (#36145)
Signed-off-by: Kunshang Ji <jikunshang95@gmail.com> Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
This commit is contained in:
@@ -134,7 +134,7 @@ class CoreEngineProcManager:
|
||||
for proc, local_dp_rank in zip(self.processes, local_dp_ranks):
|
||||
# Adjust device control in DP for non-CUDA platforms
|
||||
# as well as external and ray launchers
|
||||
# For CUDA platforms, we use torch.cuda.set_device()
|
||||
# For CUDA platforms, we use torch.accelerator.set_device_index()
|
||||
if is_dp and (
|
||||
not current_platform.is_cuda_alike()
|
||||
or vllm_config.parallel_config.use_ray
|
||||
|
||||
@@ -73,8 +73,8 @@ class SMControlContextManager:
|
||||
assert current_platform.is_cuda(), (
|
||||
"SM control is currently only supported on CUDA"
|
||||
)
|
||||
|
||||
total_sms = num_compute_units(torch.cuda.current_device())
|
||||
device = torch.accelerator.current_device_index()
|
||||
total_sms = num_compute_units(device)
|
||||
|
||||
assert comm_sms < total_sms
|
||||
self.total_sms = total_sms
|
||||
@@ -204,7 +204,7 @@ class UBatchWrapper:
|
||||
|
||||
@torch.inference_mode()
|
||||
def _capture_ubatch_thread(results, ubatch_metadata):
|
||||
torch.cuda.set_device(self.device)
|
||||
torch.accelerator.set_device_index(self.device)
|
||||
ubatch_context = ubatch_metadata.context
|
||||
with torch.cuda.stream(ubatch_context.compute_stream):
|
||||
_ = torch.cuda.current_blas_handle()
|
||||
|
||||
@@ -239,11 +239,11 @@ class Worker(WorkerBase):
|
||||
|
||||
# DP_LOCAL_RANK * TP_PP_WORLD_SIZE + TP_LOCAL_RANK
|
||||
self.local_rank += dp_local_rank * tp_pp_world_size
|
||||
assert self.local_rank < torch.cuda.device_count(), (
|
||||
assert self.local_rank < torch.accelerator.device_count(), (
|
||||
f"DP adjusted local rank {self.local_rank} is out of bounds. "
|
||||
)
|
||||
visible_device_count = (
|
||||
torch.cuda.device_count() if torch.cuda.is_available() else 0
|
||||
torch.accelerator.device_count() if torch.cuda.is_available() else 0
|
||||
)
|
||||
assert self.parallel_config.local_world_size <= visible_device_count, (
|
||||
f"local_world_size ({self.parallel_config.local_world_size}) must "
|
||||
@@ -252,7 +252,7 @@ class Worker(WorkerBase):
|
||||
)
|
||||
|
||||
self.device = torch.device(f"cuda:{self.local_rank}")
|
||||
current_platform.set_device(self.device)
|
||||
torch.accelerator.set_device_index(self.device)
|
||||
|
||||
current_platform.check_if_supports_dtype(self.model_config.dtype)
|
||||
|
||||
|
||||
@@ -60,7 +60,7 @@ class XPUWorker(Worker):
|
||||
and current_platform.is_xpu()
|
||||
):
|
||||
self.device = torch.device(f"xpu:{self.local_rank}")
|
||||
current_platform.set_device(self.device)
|
||||
torch.accelerator.set_device_index(self.device)
|
||||
current_platform.check_if_supports_dtype(self.model_config.dtype)
|
||||
torch.accelerator.empty_cache()
|
||||
self.init_gpu_memory = torch.xpu.get_device_properties(
|
||||
|
||||
Reference in New Issue
Block a user