Fix CuTeDSL from_dlpack device mismatch inside CUDA graph capture

When capturing CUDA graphs on non-default GPUs, torch.cuda.current_device()
may not match the tensor's device. from_dlpack() checks this and fails.
Fix: set the current device to match the tensor's device before from_dlpack.

This enables graph capture on all 8 GPUs, not just cuda:0.
This commit is contained in:
2026-06-03 20:34:24 +00:00
parent 2661cebe9a
commit 87b6c9932b

View File

@@ -99,6 +99,10 @@ def warmup_compilation(num_experts, K_packed, N_packed, device,
)
def to_cute(t):
# Ensure current CUDA device matches tensor device (required for from_dlpack
# inside CUDA graph capture, where torch.cuda.current_device() may not match)
if t.is_cuda and t.device.index != torch.cuda.current_device():
torch.cuda.set_device(t.device.index)
ct = cutlass_torch.from_dlpack(t)
return ct.mark_layout_dynamic(leading_dim=cutlass_torch.get_leading_dim(t))
@@ -203,6 +207,10 @@ def run_nvfp4_grouped_gemm(
)
def to_cute(t):
# Ensure current CUDA device matches tensor device (required for from_dlpack
# inside CUDA graph capture, where torch.cuda.current_device() may not match)
if t.is_cuda and t.device.index != torch.cuda.current_device():
torch.cuda.set_device(t.device.index)
ct = cutlass_torch.from_dlpack(t)
return ct.mark_layout_dynamic(leading_dim=cutlass_torch.get_leading_dim(t))
@@ -250,6 +258,10 @@ def run_nvfp4_grouped_gemm(
# This is cheap (metadata only, no GPU work) and avoids stale
# references to tensors from previous calls that may have been freed.
def to_cute(t):
# Ensure current CUDA device matches tensor device (required for from_dlpack
# inside CUDA graph capture, where torch.cuda.current_device() may not match)
if t.is_cuda and t.device.index != torch.cuda.current_device():
torch.cuda.set_device(t.device.index)
ct = cutlass_torch.from_dlpack(t)
return ct.mark_layout_dynamic(leading_dim=cutlass_torch.get_leading_dim(t))
@@ -328,6 +340,10 @@ def warmup_fused_swiglu_compilation(num_experts, K_packed, N_packed, device,
)
def to_cute(t):
# Ensure current CUDA device matches tensor device (required for from_dlpack
# inside CUDA graph capture, where torch.cuda.current_device() may not match)
if t.is_cuda and t.device.index != torch.cuda.current_device():
torch.cuda.set_device(t.device.index)
ct = cutlass_torch.from_dlpack(t)
return ct.mark_layout_dynamic(leading_dim=cutlass_torch.get_leading_dim(t))
@@ -425,6 +441,10 @@ def run_fused_swiglu_grouped_gemm(
)
def to_cute(t):
# Ensure current CUDA device matches tensor device (required for from_dlpack
# inside CUDA graph capture, where torch.cuda.current_device() may not match)
if t.is_cuda and t.device.index != torch.cuda.current_device():
torch.cuda.set_device(t.device.index)
ct = cutlass_torch.from_dlpack(t)
return ct.mark_layout_dynamic(leading_dim=cutlass_torch.get_leading_dim(t))
@@ -466,6 +486,10 @@ def run_fused_swiglu_grouped_gemm(
workspace = entry['workspace']
def to_cute(t):
# Ensure current CUDA device matches tensor device (required for from_dlpack
# inside CUDA graph capture, where torch.cuda.current_device() may not match)
if t.is_cuda and t.device.index != torch.cuda.current_device():
torch.cuda.set_device(t.device.index)
ct = cutlass_torch.from_dlpack(t)
return ct.mark_layout_dynamic(leading_dim=cutlass_torch.get_leading_dim(t))