replace `with torch.cuda.device` with `with torch.accelerator.device_index` (#36144)

Signed-off-by: Yan Ma <yan.ma@intel.com>
This commit is contained in:
Yan Ma
2026-03-12 14:12:57 +08:00
committed by GitHub
parent 584a3f56de
commit 894843eb25
10 changed files with 17 additions and 15 deletions

View File

@@ -626,7 +626,11 @@ class BenchmarkWorker:
if visible_device != f"{self.device_id}": if visible_device != f"{self.device_id}":
need_device_guard = True need_device_guard = True
with torch.cuda.device(self.device_id) if need_device_guard else nullcontext(): with (
torch.accelerator.device_index(self.device_id)
if need_device_guard
else nullcontext()
):
for idx, config in enumerate(tqdm(search_space)): for idx, config in enumerate(tqdm(search_space)):
try: try:
kernel_time = benchmark_config( kernel_time = benchmark_config(

View File

@@ -8,8 +8,8 @@ import regex as re
# Regex: match `torch.cuda.xxx` but allow `torch.accelerator.xxx` # Regex: match `torch.cuda.xxx` but allow `torch.accelerator.xxx`
# --------------------------------------------------------------------------- # # --------------------------------------------------------------------------- #
_TORCH_CUDA_PATTERNS = [ _TORCH_CUDA_PATTERNS = [
r"\btorch\.cuda\.empty_cache\b", r"\btorch\.cuda\.(empty_cache|synchronize|device\()\b",
r"\btorch\.cuda\.synchronize\b", r"\bwith\btorch\.cuda\.device\b",
] ]
ALLOWED_FILES = {"vllm/platforms/", "vllm/device_allocator/"} ALLOWED_FILES = {"vllm/platforms/", "vllm/device_allocator/"}

View File

@@ -133,9 +133,7 @@ class PyNcclCommunicator:
assert isinstance(device, torch.device) assert isinstance(device, torch.device)
self.device = device self.device = device
# nccl communicator and stream will use this device # nccl communicator and stream will use this device
# `torch.cuda.device` is a context manager that changes the with torch.accelerator.device_index(device.index):
# current cuda device to the specified one
with torch.cuda.device(device):
self.comm: ncclComm_t = self.nccl.ncclCommInitRank( self.comm: ncclComm_t = self.nccl.ncclCommInitRank(
self.world_size, self.unique_id, self.rank self.world_size, self.unique_id, self.rank
) )

View File

@@ -218,7 +218,7 @@ class P2pNcclEngine:
data = {"cmd": "NEW", "unique_id": bytes(unique_id.internal)} data = {"cmd": "NEW", "unique_id": bytes(unique_id.internal)}
sock.send(msgpack.dumps(data)) sock.send(msgpack.dumps(data))
with torch.cuda.device(self.device): with torch.accelerator.device_index(self.device.index):
rank = 0 rank = 0
with set_p2p_nccl_context(self.nccl_num_channels): with set_p2p_nccl_context(self.nccl_num_channels):
comm: ncclComm_t = self.nccl.ncclCommInitRank(2, unique_id, rank) comm: ncclComm_t = self.nccl.ncclCommInitRank(2, unique_id, rank)
@@ -377,7 +377,7 @@ class P2pNcclEngine:
data = msgpack.loads(message) data = msgpack.loads(message)
if data["cmd"] == "NEW": if data["cmd"] == "NEW":
unique_id = self.nccl.unique_id_from_bytes(bytes(data["unique_id"])) unique_id = self.nccl.unique_id_from_bytes(bytes(data["unique_id"]))
with torch.cuda.device(self.device): with torch.accelerator.device_index(self.device.index):
rank = 1 rank = 1
with set_p2p_nccl_context(self.nccl_num_channels): with set_p2p_nccl_context(self.nccl_num_channels):
comm: ncclComm_t = self.nccl.ncclCommInitRank( comm: ncclComm_t = self.nccl.ncclCommInitRank(

View File

@@ -105,7 +105,7 @@ def input_guard(fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
break break
if tensor is not None: if tensor is not None:
ctx = torch.cuda.device(tensor.device.index) ctx = torch.accelerator.device_index(tensor.device.index)
else: else:
ctx = contextlib.nullcontext() ctx = contextlib.nullcontext()

View File

@@ -119,7 +119,7 @@ def _layer_norm_fwd(
# heuristics for number of warps # heuristics for number of warps
num_warps = min(max(BLOCK_N // 256, 1), 8) num_warps = min(max(BLOCK_N // 256, 1), 8)
grid = (M, ngroups) grid = (M, ngroups)
with torch.cuda.device(x.device.index): with torch.accelerator.device_index(x.device.index):
_layer_norm_fwd_1pass_kernel[grid]( _layer_norm_fwd_1pass_kernel[grid](
x, x,
out, out,

View File

@@ -419,7 +419,7 @@ def selective_state_update(
and dt.stride(-1) == 0 and dt.stride(-1) == 0
and dt_bias.stride(-1) == 0 and dt_bias.stride(-1) == 0
) )
with torch.cuda.device(x.device.index): with torch.accelerator.device_index(x.device.index):
_selective_scan_update_kernel[grid]( _selective_scan_update_kernel[grid](
state, state,
x, x,

View File

@@ -185,7 +185,7 @@ def _bmm_chunk_fwd(a, b, chunk_size, cu_chunk_seqlens, causal=False, output_dtyp
* triton.cdiv(chunk_size, META["BLOCK_SIZE_N"]), * triton.cdiv(chunk_size, META["BLOCK_SIZE_N"]),
nchunks * ngroups, nchunks * ngroups,
) )
with torch.cuda.device(a.device.index): with torch.accelerator.device_index(a.device.index):
_bmm_chunk_fwd_kernel[grid]( _bmm_chunk_fwd_kernel[grid](
a_ptr=a, a_ptr=a,
b_ptr=b, b_ptr=b,

View File

@@ -323,7 +323,7 @@ def _chunk_cumsum_fwd(
nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32 nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32
) )
grid_chunk_cs = lambda META: (nchunks, triton.cdiv(nheads, META["BLOCK_SIZE_H"])) grid_chunk_cs = lambda META: (nchunks, triton.cdiv(nheads, META["BLOCK_SIZE_H"]))
with torch.cuda.device(dt.device.index): with torch.accelerator.device_index(dt.device.index):
_chunk_cumsum_fwd_kernel[grid_chunk_cs]( _chunk_cumsum_fwd_kernel[grid_chunk_cs](
dt_ptr=dt, dt_ptr=dt,
A_ptr=A, A_ptr=A,
@@ -378,7 +378,7 @@ def _chunk_state_fwd(
nchunks, nchunks,
nheads, nheads,
) )
with torch.cuda.device(x.device.index): with torch.accelerator.device_index(x.device.index):
_chunk_state_fwd_kernel[grid]( _chunk_state_fwd_kernel[grid](
x_ptr=x, x_ptr=x,
b_ptr=B, b_ptr=B,

View File

@@ -120,7 +120,7 @@ def _state_passing_fwd(
) )
grid = lambda META: (triton.cdiv(dim, META["BLOCK_SIZE"]), batch, nheads) grid = lambda META: (triton.cdiv(dim, META["BLOCK_SIZE"]), batch, nheads)
with torch.cuda.device(states.device.index): with torch.accelerator.device_index(states.device.index):
_state_passing_fwd_kernel[grid]( _state_passing_fwd_kernel[grid](
states_ptr=states, states_ptr=states,
out_ptr=out, out_ptr=out,