Fix CuTeDSL from_dlpack device mismatch in CUDA graph capture (v3)

Patch torch.cuda.current_device to return the tensor's device index
during from_dlpack calls inside CUDA graph capture. This bypasses the
device check in __dlpack__ without changing the CUDA stream (which
caused 'Capture must end on the same stream' in v1) and without
triggering a cross-device copy (which caused 'Cannot copy between
CPU and CUDA tensors' in v2).
This commit is contained in:
2026-06-03 21:09:12 +00:00
parent 5c94dbbc37
commit 91c370360a

View File

@@ -27,30 +27,6 @@ from dsv4.ops.layouts import (
)
class _DLPatchTensor:
    """Proxy that forces the DLPack export device of a wrapped CUDA tensor.

    from_dlpack checks torch.cuda.current_device() against the tensor's
    device. Inside CUDA graph capture on non-default GPUs, current_device()
    may not match. This wrapper passes an explicit ``dl_device`` to the
    wrapped tensor's ``__dlpack__`` so the export targets the tensor's own
    device.

    DO NOT use torch.cuda.set_device() inside graph capture -- it changes the
    stream and causes 'Capture must end on the same stream it began on'.

    Every attribute other than the DLPack protocol methods is forwarded to
    the wrapped tensor.
    """

    __slots__ = ('_t',)

    # DLPack DLDeviceType code for CUDA. Code 1 is kDLCPU; requesting it for a
    # CUDA tensor asks the producer for a cross-device (CPU) export.
    _KDLCUDA = 2

    def __init__(self, t):
        self._t = t

    def __dlpack__(self, *args, **kwargs):
        """Export the wrapped tensor, pinning ``dl_device`` to its own device.

        Any caller-supplied ``dl_device`` is overridden. If the wrapped
        ``__dlpack__`` rejects the keyword with ``TypeError`` (older protocol
        signature), the call is retried without it.
        """
        kwargs['dl_device'] = self.__dlpack_device__()
        try:
            return self._t.__dlpack__(*args, **kwargs)
        except TypeError:
            # Retry without the keyword; a genuine TypeError re-raises here.
            kwargs.pop('dl_device')
            return self._t.__dlpack__(*args, **kwargs)

    def __dlpack_device__(self):
        """Return ``(kDLCUDA, index)`` for the wrapped tensor's device."""
        return (self._KDLCUDA, self._t.device.index)

    def __getattr__(self, name):
        # Only reached when normal lookup fails. If the '_t' slot itself is
        # unset (e.g. instance built via __new__), delegating through self._t
        # would recurse forever, so fail fast instead.
        if name == '_t':
            raise AttributeError(name)
        return getattr(self._t, name)
# Cache compiled kernels + pre-allocated workspace by cache_key
# Each entry: {'compiled': callable, 'workspace': Tensor, 'workspace_size': int}
@@ -127,11 +103,13 @@ def warmup_compilation(num_experts, K_packed, N_packed, device,
def to_cute(t):
    """Convert torch tensor ``t`` to a CuTe tensor with a dynamic layout.

    from_dlpack checks torch.cuda.current_device() against the tensor's
    device. Inside CUDA graph capture on non-default GPUs, current_device()
    may not match, so torch.cuda.current_device is temporarily replaced with
    a function returning the tensor's device index.

    DO NOT use torch.cuda.set_device() inside graph capture -- it changes the
    stream.
    """
    orig_current_device = torch.cuda.current_device
    if t.is_cuda and t.device.index != orig_current_device():
        device_index = t.device.index
        # NOTE: this rebinds a module-level attribute, so the override is
        # visible process-wide until it is restored below.
        torch.cuda.current_device = lambda: device_index
    try:
        ct = cutlass_torch.from_dlpack(t)
    finally:
        # Restore even if from_dlpack raises; otherwise the patched function
        # would keep reporting this tensor's device for all later callers.
        torch.cuda.current_device = orig_current_device
    return ct.mark_layout_dynamic(leading_dim=cutlass_torch.get_leading_dim(t))
a_c = to_cute(mat_a)
@@ -235,9 +213,11 @@ def run_nvfp4_grouped_gemm(
)
def to_cute(t):
    """Convert torch tensor ``t`` to a CuTe tensor with a dynamic layout.

    torch.cuda.current_device is temporarily replaced with a function that
    returns the tensor's device index whenever the two differ, for the
    duration of the from_dlpack call only.
    """
    orig_current_device = torch.cuda.current_device
    if t.is_cuda and t.device.index != orig_current_device():
        device_index = t.device.index
        # NOTE: this rebinds a module-level attribute, so the override is
        # visible process-wide until it is restored below.
        torch.cuda.current_device = lambda: device_index
    try:
        ct = cutlass_torch.from_dlpack(t)
    finally:
        # Restore even if from_dlpack raises; otherwise the patched function
        # would keep reporting this tensor's device for all later callers.
        torch.cuda.current_device = orig_current_device
    return ct.mark_layout_dynamic(leading_dim=cutlass_torch.get_leading_dim(t))
a_c = to_cute(mat_a)
@@ -286,11 +266,13 @@ def run_nvfp4_grouped_gemm(
def to_cute(t):
    """Convert torch tensor ``t`` to a CuTe tensor with a dynamic layout.

    from_dlpack checks torch.cuda.current_device() against the tensor's
    device. Inside CUDA graph capture on non-default GPUs, current_device()
    may not match, so torch.cuda.current_device is temporarily replaced with
    a function returning the tensor's device index.

    DO NOT use torch.cuda.set_device() inside graph capture -- it changes the
    stream.
    """
    orig_current_device = torch.cuda.current_device
    if t.is_cuda and t.device.index != orig_current_device():
        device_index = t.device.index
        # NOTE: this rebinds a module-level attribute, so the override is
        # visible process-wide until it is restored below.
        torch.cuda.current_device = lambda: device_index
    try:
        ct = cutlass_torch.from_dlpack(t)
    finally:
        # Restore even if from_dlpack raises; otherwise the patched function
        # would keep reporting this tensor's device for all later callers.
        torch.cuda.current_device = orig_current_device
    return ct.mark_layout_dynamic(leading_dim=cutlass_torch.get_leading_dim(t))
a_c = to_cute(mat_a)
@@ -370,11 +352,13 @@ def warmup_fused_swiglu_compilation(num_experts, K_packed, N_packed, device,
def to_cute(t):
    """Convert torch tensor ``t`` to a CuTe tensor with a dynamic layout.

    from_dlpack checks torch.cuda.current_device() against the tensor's
    device. Inside CUDA graph capture on non-default GPUs, current_device()
    may not match, so torch.cuda.current_device is temporarily replaced with
    a function returning the tensor's device index.

    DO NOT use torch.cuda.set_device() inside graph capture -- it changes the
    stream.
    """
    orig_current_device = torch.cuda.current_device
    if t.is_cuda and t.device.index != orig_current_device():
        device_index = t.device.index
        # NOTE: this rebinds a module-level attribute, so the override is
        # visible process-wide until it is restored below.
        torch.cuda.current_device = lambda: device_index
    try:
        ct = cutlass_torch.from_dlpack(t)
    finally:
        # Restore even if from_dlpack raises; otherwise the patched function
        # would keep reporting this tensor's device for all later callers.
        torch.cuda.current_device = orig_current_device
    return ct.mark_layout_dynamic(leading_dim=cutlass_torch.get_leading_dim(t))
a_c = to_cute(mat_a)
@@ -471,9 +455,11 @@ def run_fused_swiglu_grouped_gemm(
)
def to_cute(t):
    """Convert torch tensor ``t`` to a CuTe tensor with a dynamic layout.

    torch.cuda.current_device is temporarily replaced with a function that
    returns the tensor's device index whenever the two differ, for the
    duration of the from_dlpack call only.
    """
    orig_current_device = torch.cuda.current_device
    if t.is_cuda and t.device.index != orig_current_device():
        device_index = t.device.index
        # NOTE: this rebinds a module-level attribute, so the override is
        # visible process-wide until it is restored below.
        torch.cuda.current_device = lambda: device_index
    try:
        ct = cutlass_torch.from_dlpack(t)
    finally:
        # Restore even if from_dlpack raises; otherwise the patched function
        # would keep reporting this tensor's device for all later callers.
        torch.cuda.current_device = orig_current_device
    return ct.mark_layout_dynamic(leading_dim=cutlass_torch.get_leading_dim(t))
a_c = to_cute(mat_a)
@@ -516,11 +502,13 @@ def run_fused_swiglu_grouped_gemm(
def to_cute(t):
    """Convert torch tensor ``t`` to a CuTe tensor with a dynamic layout.

    from_dlpack checks torch.cuda.current_device() against the tensor's
    device. Inside CUDA graph capture on non-default GPUs, current_device()
    may not match, so torch.cuda.current_device is temporarily replaced with
    a function returning the tensor's device index.

    DO NOT use torch.cuda.set_device() inside graph capture -- it changes the
    stream.
    """
    orig_current_device = torch.cuda.current_device
    if t.is_cuda and t.device.index != orig_current_device():
        device_index = t.device.index
        # NOTE: this rebinds a module-level attribute, so the override is
        # visible process-wide until it is restored below.
        torch.cuda.current_device = lambda: device_index
    try:
        ct = cutlass_torch.from_dlpack(t)
    finally:
        # Restore even if from_dlpack raises; otherwise the patched function
        # would keep reporting this tensor's device for all later callers.
        torch.cuda.current_device = orig_current_device
    return ct.mark_layout_dynamic(leading_dim=cutlass_torch.get_leading_dim(t))
a_c = to_cute(mat_a)