Fix CuTeDSL from_dlpack device mismatch inside CUDA graph capture
When capturing CUDA graphs on non-default GPUs, torch.cuda.current_device() may not match the tensor's device. from_dlpack() verifies that the two agree and fails on a mismatch. Fix: set the current device to match the tensor's device before calling from_dlpack(). This enables graph capture on all 8 GPUs, not just cuda:0.
This commit is contained in:
@@ -99,6 +99,10 @@ def warmup_compilation(num_experts, K_packed, N_packed, device,
|
||||
)
|
||||
|
||||
def to_cute(t):
    """Wrap torch tensor ``t`` as a CuTe tensor with a dynamic layout.

    The tensor is converted with ``cutlass_torch.from_dlpack`` and its layout
    is marked dynamic along the leading dimension reported by
    ``cutlass_torch.get_leading_dim``.

    Side effect: when ``t`` is a CUDA tensor on a device other than the
    current one, the current CUDA device is switched to ``t``'s device via
    ``torch.cuda.set_device`` and is not switched back afterwards.
    """
    if t.is_cuda:
        tensor_dev = t.device.index
        # from_dlpack requires the current CUDA device to match the tensor's
        # device; inside CUDA graph capture torch.cuda.current_device() may
        # point at a different GPU.
        if tensor_dev != torch.cuda.current_device():
            torch.cuda.set_device(tensor_dev)

    mark_dynamic = cutlass_torch.from_dlpack(t).mark_layout_dynamic
    return mark_dynamic(leading_dim=cutlass_torch.get_leading_dim(t))
|
||||
|
||||
@@ -203,6 +207,10 @@ def run_nvfp4_grouped_gemm(
|
||||
)
|
||||
|
||||
def to_cute(t):
    """Wrap torch tensor ``t`` as a CuTe tensor with a dynamic layout.

    The tensor is converted with ``cutlass_torch.from_dlpack`` and its layout
    is marked dynamic along the leading dimension reported by
    ``cutlass_torch.get_leading_dim``.

    Side effect: when ``t`` is a CUDA tensor on a device other than the
    current one, the current CUDA device is switched to ``t``'s device via
    ``torch.cuda.set_device`` and is not switched back afterwards.
    """
    if t.is_cuda:
        tensor_dev = t.device.index
        # from_dlpack requires the current CUDA device to match the tensor's
        # device; inside CUDA graph capture torch.cuda.current_device() may
        # point at a different GPU.
        if tensor_dev != torch.cuda.current_device():
            torch.cuda.set_device(tensor_dev)

    mark_dynamic = cutlass_torch.from_dlpack(t).mark_layout_dynamic
    return mark_dynamic(leading_dim=cutlass_torch.get_leading_dim(t))
|
||||
|
||||
@@ -250,6 +258,10 @@ def run_nvfp4_grouped_gemm(
|
||||
# This is cheap (metadata only, no GPU work) and avoids stale
|
||||
# references to tensors from previous calls that may have been freed.
|
||||
def to_cute(t):
    """Wrap torch tensor ``t`` as a CuTe tensor with a dynamic layout.

    The tensor is converted with ``cutlass_torch.from_dlpack`` and its layout
    is marked dynamic along the leading dimension reported by
    ``cutlass_torch.get_leading_dim``.

    Side effect: when ``t`` is a CUDA tensor on a device other than the
    current one, the current CUDA device is switched to ``t``'s device via
    ``torch.cuda.set_device`` and is not switched back afterwards.
    """
    if t.is_cuda:
        tensor_dev = t.device.index
        # from_dlpack requires the current CUDA device to match the tensor's
        # device; inside CUDA graph capture torch.cuda.current_device() may
        # point at a different GPU.
        if tensor_dev != torch.cuda.current_device():
            torch.cuda.set_device(tensor_dev)

    mark_dynamic = cutlass_torch.from_dlpack(t).mark_layout_dynamic
    return mark_dynamic(leading_dim=cutlass_torch.get_leading_dim(t))
|
||||
|
||||
@@ -328,6 +340,10 @@ def warmup_fused_swiglu_compilation(num_experts, K_packed, N_packed, device,
|
||||
)
|
||||
|
||||
def to_cute(t):
    """Wrap torch tensor ``t`` as a CuTe tensor with a dynamic layout.

    The tensor is converted with ``cutlass_torch.from_dlpack`` and its layout
    is marked dynamic along the leading dimension reported by
    ``cutlass_torch.get_leading_dim``.

    Side effect: when ``t`` is a CUDA tensor on a device other than the
    current one, the current CUDA device is switched to ``t``'s device via
    ``torch.cuda.set_device`` and is not switched back afterwards.
    """
    if t.is_cuda:
        tensor_dev = t.device.index
        # from_dlpack requires the current CUDA device to match the tensor's
        # device; inside CUDA graph capture torch.cuda.current_device() may
        # point at a different GPU.
        if tensor_dev != torch.cuda.current_device():
            torch.cuda.set_device(tensor_dev)

    mark_dynamic = cutlass_torch.from_dlpack(t).mark_layout_dynamic
    return mark_dynamic(leading_dim=cutlass_torch.get_leading_dim(t))
|
||||
|
||||
@@ -425,6 +441,10 @@ def run_fused_swiglu_grouped_gemm(
|
||||
)
|
||||
|
||||
def to_cute(t):
    """Wrap torch tensor ``t`` as a CuTe tensor with a dynamic layout.

    The tensor is converted with ``cutlass_torch.from_dlpack`` and its layout
    is marked dynamic along the leading dimension reported by
    ``cutlass_torch.get_leading_dim``.

    Side effect: when ``t`` is a CUDA tensor on a device other than the
    current one, the current CUDA device is switched to ``t``'s device via
    ``torch.cuda.set_device`` and is not switched back afterwards.
    """
    if t.is_cuda:
        tensor_dev = t.device.index
        # from_dlpack requires the current CUDA device to match the tensor's
        # device; inside CUDA graph capture torch.cuda.current_device() may
        # point at a different GPU.
        if tensor_dev != torch.cuda.current_device():
            torch.cuda.set_device(tensor_dev)

    mark_dynamic = cutlass_torch.from_dlpack(t).mark_layout_dynamic
    return mark_dynamic(leading_dim=cutlass_torch.get_leading_dim(t))
|
||||
|
||||
@@ -466,6 +486,10 @@ def run_fused_swiglu_grouped_gemm(
|
||||
workspace = entry['workspace']
|
||||
|
||||
def to_cute(t):
    """Wrap torch tensor ``t`` as a CuTe tensor with a dynamic layout.

    The tensor is converted with ``cutlass_torch.from_dlpack`` and its layout
    is marked dynamic along the leading dimension reported by
    ``cutlass_torch.get_leading_dim``.

    Side effect: when ``t`` is a CUDA tensor on a device other than the
    current one, the current CUDA device is switched to ``t``'s device via
    ``torch.cuda.set_device`` and is not switched back afterwards.
    """
    if t.is_cuda:
        tensor_dev = t.device.index
        # from_dlpack requires the current CUDA device to match the tensor's
        # device; inside CUDA graph capture torch.cuda.current_device() may
        # point at a different GPU.
        if tensor_dev != torch.cuda.current_device():
            torch.cuda.set_device(tensor_dev)

    mark_dynamic = cutlass_torch.from_dlpack(t).mark_layout_dynamic
    return mark_dynamic(leading_dim=cutlass_torch.get_leading_dim(t))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user