From 87b6c9932b0e2f79039ddaef929df14281f64c56 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Wed, 3 Jun 2026 20:34:24 +0000 Subject: [PATCH] Fix CuTeDSL from_dlpack device mismatch inside CUDA graph capture When capturing CUDA graphs on non-default GPUs, torch.cuda.current_device() may not match the tensor's device. from_dlpack() checks this and fails. Fix: set the current device to match the tensor's device before from_dlpack. This enables graph capture on all 8 GPUs, not just cuda:0. --- dsv4/ops/gemm_runner.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/dsv4/ops/gemm_runner.py b/dsv4/ops/gemm_runner.py index a7d24b68..5071e51f 100644 --- a/dsv4/ops/gemm_runner.py +++ b/dsv4/ops/gemm_runner.py @@ -99,6 +99,10 @@ def warmup_compilation(num_experts, K_packed, N_packed, device, ) def to_cute(t): + # Ensure current CUDA device matches tensor device (required for from_dlpack + # inside CUDA graph capture, where torch.cuda.current_device() may not match) + if t.is_cuda and t.device.index != torch.cuda.current_device(): + torch.cuda.set_device(t.device.index) ct = cutlass_torch.from_dlpack(t) return ct.mark_layout_dynamic(leading_dim=cutlass_torch.get_leading_dim(t)) @@ -203,6 +207,10 @@ def run_nvfp4_grouped_gemm( ) def to_cute(t): + # Ensure current CUDA device matches tensor device (required for from_dlpack + # inside CUDA graph capture, where torch.cuda.current_device() may not match) + if t.is_cuda and t.device.index != torch.cuda.current_device(): + torch.cuda.set_device(t.device.index) ct = cutlass_torch.from_dlpack(t) return ct.mark_layout_dynamic(leading_dim=cutlass_torch.get_leading_dim(t)) @@ -250,6 +258,10 @@ def run_nvfp4_grouped_gemm( # This is cheap (metadata only, no GPU work) and avoids stale # references to tensors from previous calls that may have been freed. def to_cute(t): + # Ensure current CUDA device matches tensor device (required for from_dlpack + # inside CUDA graph capture, where torch.cuda.current_device() may not match) + if t.is_cuda and t.device.index != torch.cuda.current_device(): + torch.cuda.set_device(t.device.index) ct = cutlass_torch.from_dlpack(t) return ct.mark_layout_dynamic(leading_dim=cutlass_torch.get_leading_dim(t)) @@ -328,6 +340,10 @@ def warmup_fused_swiglu_compilation(num_experts, K_packed, N_packed, device, ) def to_cute(t): + # Ensure current CUDA device matches tensor device (required for from_dlpack + # inside CUDA graph capture, where torch.cuda.current_device() may not match) + if t.is_cuda and t.device.index != torch.cuda.current_device(): + torch.cuda.set_device(t.device.index) ct = cutlass_torch.from_dlpack(t) return ct.mark_layout_dynamic(leading_dim=cutlass_torch.get_leading_dim(t)) @@ -425,6 +441,10 @@ def run_fused_swiglu_grouped_gemm( ) def to_cute(t): + # Ensure current CUDA device matches tensor device (required for from_dlpack + # inside CUDA graph capture, where torch.cuda.current_device() may not match) + if t.is_cuda and t.device.index != torch.cuda.current_device(): + torch.cuda.set_device(t.device.index) ct = cutlass_torch.from_dlpack(t) return ct.mark_layout_dynamic(leading_dim=cutlass_torch.get_leading_dim(t)) @@ -466,6 +486,10 @@ def run_fused_swiglu_grouped_gemm( workspace = entry['workspace'] def to_cute(t): + # Ensure current CUDA device matches tensor device (required for from_dlpack + # inside CUDA graph capture, where torch.cuda.current_device() may not match) + if t.is_cuda and t.device.index != torch.cuda.current_device(): + torch.cuda.set_device(t.device.index) ct = cutlass_torch.from_dlpack(t) return ct.mark_layout_dynamic(leading_dim=cutlass_torch.get_leading_dim(t))