diff --git a/benchmarks/kernels/benchmark_moe.py b/benchmarks/kernels/benchmark_moe.py index 9ef825417..cf49232fd 100644 --- a/benchmarks/kernels/benchmark_moe.py +++ b/benchmarks/kernels/benchmark_moe.py @@ -626,7 +626,11 @@ class BenchmarkWorker: if visible_device != f"{self.device_id}": need_device_guard = True - with torch.cuda.device(self.device_id) if need_device_guard else nullcontext(): + with ( + torch.accelerator.device_index(self.device_id) + if need_device_guard + else nullcontext() + ): for idx, config in enumerate(tqdm(search_space)): try: kernel_time = benchmark_config( diff --git a/tools/pre_commit/check_torch_cuda.py b/tools/pre_commit/check_torch_cuda.py index 356650863..42cb0945b 100644 --- a/tools/pre_commit/check_torch_cuda.py +++ b/tools/pre_commit/check_torch_cuda.py @@ -8,8 +8,8 @@ import regex as re # Regex: match `torch.cuda.xxx` but allow `torch.accelerator.xxx` # --------------------------------------------------------------------------- # _TORCH_CUDA_PATTERNS = [ - r"\btorch\.cuda\.empty_cache\b", - r"\btorch\.cuda\.synchronize\b", + r"\btorch\.cuda\.(empty_cache\b|synchronize\b|device\()", + r"\bwith\s+torch\.cuda\.device\b", ] ALLOWED_FILES = {"vllm/platforms/", "vllm/device_allocator/"} diff --git a/vllm/distributed/device_communicators/pynccl.py b/vllm/distributed/device_communicators/pynccl.py index 44dc113e4..84a032541 100644 --- a/vllm/distributed/device_communicators/pynccl.py +++ b/vllm/distributed/device_communicators/pynccl.py @@ -133,9 +133,7 @@ class PyNcclCommunicator: assert isinstance(device, torch.device) self.device = device # nccl communicator and stream will use this device - # `torch.cuda.device` is a context manager that changes the - # current cuda device to the specified one - with torch.cuda.device(device): + with torch.accelerator.device_index(device.index): self.comm: ncclComm_t = self.nccl.ncclCommInitRank( self.world_size, self.unique_id, self.rank ) diff --git 
a/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py index 0e748db66..1c1410f39 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py @@ -218,7 +218,7 @@ class P2pNcclEngine: data = {"cmd": "NEW", "unique_id": bytes(unique_id.internal)} sock.send(msgpack.dumps(data)) - with torch.cuda.device(self.device): + with torch.accelerator.device_index(self.device.index): rank = 0 with set_p2p_nccl_context(self.nccl_num_channels): comm: ncclComm_t = self.nccl.ncclCommInitRank(2, unique_id, rank) @@ -377,7 +377,7 @@ class P2pNcclEngine: data = msgpack.loads(message) if data["cmd"] == "NEW": unique_id = self.nccl.unique_id_from_bytes(bytes(data["unique_id"])) - with torch.cuda.device(self.device): + with torch.accelerator.device_index(self.device.index): rank = 1 with set_p2p_nccl_context(self.nccl_num_channels): comm: ncclComm_t = self.nccl.ncclCommInitRank( diff --git a/vllm/model_executor/layers/fla/ops/utils.py b/vllm/model_executor/layers/fla/ops/utils.py index 18e17a511..f0ec1f7a6 100644 --- a/vllm/model_executor/layers/fla/ops/utils.py +++ b/vllm/model_executor/layers/fla/ops/utils.py @@ -105,7 +105,7 @@ def input_guard(fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]: break if tensor is not None: - ctx = torch.cuda.device(tensor.device.index) + ctx = torch.accelerator.device_index(tensor.device.index) else: ctx = contextlib.nullcontext() diff --git a/vllm/model_executor/layers/mamba/ops/layernorm_gated.py b/vllm/model_executor/layers/mamba/ops/layernorm_gated.py index b592906c6..19db051cf 100644 --- a/vllm/model_executor/layers/mamba/ops/layernorm_gated.py +++ b/vllm/model_executor/layers/mamba/ops/layernorm_gated.py @@ -119,7 +119,7 @@ def _layer_norm_fwd( # heuristics for number of warps num_warps = min(max(BLOCK_N // 256, 1), 8) grid = (M, ngroups) - with 
torch.cuda.device(x.device.index): + with torch.accelerator.device_index(x.device.index): _layer_norm_fwd_1pass_kernel[grid]( x, out, diff --git a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py index 50778a990..22a99596a 100644 --- a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py +++ b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py @@ -419,7 +419,7 @@ def selective_state_update( and dt.stride(-1) == 0 and dt_bias.stride(-1) == 0 ) - with torch.cuda.device(x.device.index): + with torch.accelerator.device_index(x.device.index): _selective_scan_update_kernel[grid]( state, x, diff --git a/vllm/model_executor/layers/mamba/ops/ssd_bmm.py b/vllm/model_executor/layers/mamba/ops/ssd_bmm.py index ac5ffc10f..9b5901c38 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_bmm.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_bmm.py @@ -185,7 +185,7 @@ def _bmm_chunk_fwd(a, b, chunk_size, cu_chunk_seqlens, causal=False, output_dtyp * triton.cdiv(chunk_size, META["BLOCK_SIZE_N"]), nchunks * ngroups, ) - with torch.cuda.device(a.device.index): + with torch.accelerator.device_index(a.device.index): _bmm_chunk_fwd_kernel[grid]( a_ptr=a, b_ptr=b, diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py index ed60593f5..37532e6db 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py @@ -323,7 +323,7 @@ def _chunk_cumsum_fwd( nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32 ) grid_chunk_cs = lambda META: (nchunks, triton.cdiv(nheads, META["BLOCK_SIZE_H"])) - with torch.cuda.device(dt.device.index): + with torch.accelerator.device_index(dt.device.index): _chunk_cumsum_fwd_kernel[grid_chunk_cs]( dt_ptr=dt, A_ptr=A, @@ -378,7 +378,7 @@ def _chunk_state_fwd( nchunks, nheads, ) - with torch.cuda.device(x.device.index): + with 
torch.accelerator.device_index(x.device.index): _chunk_state_fwd_kernel[grid]( x_ptr=x, b_ptr=B, diff --git a/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py b/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py index 5c5cb9d37..bd33e7e49 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py @@ -120,7 +120,7 @@ def _state_passing_fwd( ) grid = lambda META: (triton.cdiv(dim, META["BLOCK_SIZE"]), batch, nheads) - with torch.cuda.device(states.device.index): + with torch.accelerator.device_index(states.device.index): _state_passing_fwd_kernel[grid]( states_ptr=states, out_ptr=out,