replace `with torch.cuda.device` with `with torch.accelerator.device_index` (#36144)

Signed-off-by: Yan Ma <yan.ma@intel.com>
This commit is contained in:
Yan Ma
2026-03-12 14:12:57 +08:00
committed by GitHub
parent 584a3f56de
commit 894843eb25
10 changed files with 17 additions and 15 deletions

View File

@@ -626,7 +626,11 @@ class BenchmarkWorker:
if visible_device != f"{self.device_id}":
need_device_guard = True
with torch.cuda.device(self.device_id) if need_device_guard else nullcontext():
with (
torch.accelerator.device_index(self.device_id)
if need_device_guard
else nullcontext()
):
for idx, config in enumerate(tqdm(search_space)):
try:
kernel_time = benchmark_config(

View File

@@ -8,8 +8,8 @@ import regex as re
# Regex: match `torch.cuda.xxx` but allow `torch.accelerator.xxx`
# --------------------------------------------------------------------------- #
_TORCH_CUDA_PATTERNS = [
r"\btorch\.cuda\.empty_cache\b",
r"\btorch\.cuda\.synchronize\b",
r"\btorch\.cuda\.(empty_cache|synchronize|device\()\b",
r"\bwith\btorch\.cuda\.device\b",
]
ALLOWED_FILES = {"vllm/platforms/", "vllm/device_allocator/"}

View File

@@ -133,9 +133,7 @@ class PyNcclCommunicator:
assert isinstance(device, torch.device)
self.device = device
# nccl communicator and stream will use this device
# `torch.cuda.device` is a context manager that changes the
# current cuda device to the specified one
with torch.cuda.device(device):
with torch.accelerator.device_index(device.index):
self.comm: ncclComm_t = self.nccl.ncclCommInitRank(
self.world_size, self.unique_id, self.rank
)

View File

@@ -218,7 +218,7 @@ class P2pNcclEngine:
data = {"cmd": "NEW", "unique_id": bytes(unique_id.internal)}
sock.send(msgpack.dumps(data))
with torch.cuda.device(self.device):
with torch.accelerator.device_index(self.device.index):
rank = 0
with set_p2p_nccl_context(self.nccl_num_channels):
comm: ncclComm_t = self.nccl.ncclCommInitRank(2, unique_id, rank)
@@ -377,7 +377,7 @@ class P2pNcclEngine:
data = msgpack.loads(message)
if data["cmd"] == "NEW":
unique_id = self.nccl.unique_id_from_bytes(bytes(data["unique_id"]))
with torch.cuda.device(self.device):
with torch.accelerator.device_index(self.device.index):
rank = 1
with set_p2p_nccl_context(self.nccl_num_channels):
comm: ncclComm_t = self.nccl.ncclCommInitRank(

View File

@@ -105,7 +105,7 @@ def input_guard(fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
break
if tensor is not None:
ctx = torch.cuda.device(tensor.device.index)
ctx = torch.accelerator.device_index(tensor.device.index)
else:
ctx = contextlib.nullcontext()

View File

@@ -119,7 +119,7 @@ def _layer_norm_fwd(
# heuristics for number of warps
num_warps = min(max(BLOCK_N // 256, 1), 8)
grid = (M, ngroups)
with torch.cuda.device(x.device.index):
with torch.accelerator.device_index(x.device.index):
_layer_norm_fwd_1pass_kernel[grid](
x,
out,

View File

@@ -419,7 +419,7 @@ def selective_state_update(
and dt.stride(-1) == 0
and dt_bias.stride(-1) == 0
)
with torch.cuda.device(x.device.index):
with torch.accelerator.device_index(x.device.index):
_selective_scan_update_kernel[grid](
state,
x,

View File

@@ -185,7 +185,7 @@ def _bmm_chunk_fwd(a, b, chunk_size, cu_chunk_seqlens, causal=False, output_dtyp
* triton.cdiv(chunk_size, META["BLOCK_SIZE_N"]),
nchunks * ngroups,
)
with torch.cuda.device(a.device.index):
with torch.accelerator.device_index(a.device.index):
_bmm_chunk_fwd_kernel[grid](
a_ptr=a,
b_ptr=b,

View File

@@ -323,7 +323,7 @@ def _chunk_cumsum_fwd(
nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32
)
grid_chunk_cs = lambda META: (nchunks, triton.cdiv(nheads, META["BLOCK_SIZE_H"]))
with torch.cuda.device(dt.device.index):
with torch.accelerator.device_index(dt.device.index):
_chunk_cumsum_fwd_kernel[grid_chunk_cs](
dt_ptr=dt,
A_ptr=A,
@@ -378,7 +378,7 @@ def _chunk_state_fwd(
nchunks,
nheads,
)
with torch.cuda.device(x.device.index):
with torch.accelerator.device_index(x.device.index):
_chunk_state_fwd_kernel[grid](
x_ptr=x,
b_ptr=B,

View File

@@ -120,7 +120,7 @@ def _state_passing_fwd(
)
grid = lambda META: (triton.cdiv(dim, META["BLOCK_SIZE"]), batch, nheads)
with torch.cuda.device(states.device.index):
with torch.accelerator.device_index(states.device.index):
_state_passing_fwd_kernel[grid](
states_ptr=states,
out_ptr=out,