Fix CuTeDSL from_dlpack device mismatch in CUDA graph capture (v3)
Patch torch.cuda.current_device to return the tensor's device index during from_dlpack calls inside CUDA graph capture. This bypasses the device check in __dlpack__ without changing the CUDA stream (which caused 'Capture must end on the same stream' in v1) and without triggering a cross-device copy (which caused 'Cannot copy between CPU and CUDA tensors' in v2).
This commit is contained in:
@@ -27,30 +27,6 @@ from dsv4.ops.layouts import (
|
||||
)
|
||||
|
||||
|
||||
class _DLPatchTensor:
|
||||
"""Wrapper that patches __dlpack__ to force dl_device during CUDA graph capture.
|
||||
|
||||
from_dlpack checks torch.cuda.current_device() against the tensor's device.
|
||||
Inside CUDA graph capture on non-default GPUs, current_device() may not match.
|
||||
This wrapper forces dl_device in __dlpack__, bypassing the device check.
|
||||
|
||||
DO NOT use torch.cuda.set_device() inside graph capture — it changes the stream
|
||||
and causes 'Capture must end on the same stream it began on'.
|
||||
"""
|
||||
__slots__ = ('_t',)
|
||||
def __init__(self, t):
|
||||
self._t = t
|
||||
def __dlpack__(self, *args, **kwargs):
|
||||
kwargs['dl_device'] = (1, self._t.device.index) # kDLCUDA=1
|
||||
try:
|
||||
return self._t.__dlpack__(*args, **kwargs)
|
||||
except TypeError:
|
||||
return self._t.__dlpack__(*args, **{k: v for k, v in kwargs.items() if k != 'dl_device'})
|
||||
def __dlpack_device__(self):
|
||||
return (1, self._t.device.index)
|
||||
def __getattr__(self, name):
|
||||
return getattr(self._t, name)
|
||||
|
||||
|
||||
# Cache compiled kernels + pre-allocated workspace by cache_key
|
||||
# Each entry: {'compiled': callable, 'workspace': Tensor, 'workspace_size': int}
|
||||
@@ -127,11 +103,13 @@ def warmup_compilation(num_experts, K_packed, N_packed, device,
|
||||
def to_cute(t):
    """Convert `t` to a CuTe tensor via DLPack with a dynamic leading-dim layout.

    Fix: from_dlpack checks torch.cuda.current_device() against tensor device.
    Inside CUDA graph capture on non-default GPUs, current_device() may not match.
    DO NOT use torch.cuda.set_device() inside graph capture — it changes the stream.
    """
    # While the current stream is capturing, wrap the tensor so that
    # __dlpack__ is called with dl_device forced to the tensor's own device.
    capturing = (
        hasattr(torch.cuda, 'is_current_stream_capturing')
        and torch.cuda.is_current_stream_capturing()
    )
    if capturing:
        t = _DLPatchTensor(t)

    # Temporarily patch current_device to return the tensor's device index.
    # The patch is process-wide, so it must be undone even when from_dlpack
    # raises; otherwise every later current_device() call would be wrong.
    _orig_cd = torch.cuda.current_device
    if t.is_cuda and t.device.index != _orig_cd():
        device_index = t.device.index
        torch.cuda.current_device = lambda: device_index
    try:
        ct = cutlass_torch.from_dlpack(t)
    finally:
        torch.cuda.current_device = _orig_cd
    return ct.mark_layout_dynamic(leading_dim=cutlass_torch.get_leading_dim(t))
|
||||
|
||||
a_c = to_cute(mat_a)
|
||||
@@ -235,9 +213,11 @@ def run_nvfp4_grouped_gemm(
|
||||
)
|
||||
|
||||
def to_cute(t):
    """Convert `t` to a CuTe tensor via DLPack with a dynamic leading-dim layout.

    While the current CUDA stream is capturing, the tensor is wrapped in
    _DLPatchTensor, and torch.cuda.current_device is temporarily replaced so
    it reports the tensor's device index for the duration of from_dlpack.
    """
    capturing = (
        hasattr(torch.cuda, 'is_current_stream_capturing')
        and torch.cuda.is_current_stream_capturing()
    )
    if capturing:
        t = _DLPatchTensor(t)

    # The replacement below is process-wide, so it is restored in `finally`
    # to avoid leaving it installed if from_dlpack raises.
    _orig_cd = torch.cuda.current_device
    if t.is_cuda and t.device.index != _orig_cd():
        device_index = t.device.index
        torch.cuda.current_device = lambda: device_index
    try:
        ct = cutlass_torch.from_dlpack(t)
    finally:
        torch.cuda.current_device = _orig_cd
    return ct.mark_layout_dynamic(leading_dim=cutlass_torch.get_leading_dim(t))
|
||||
|
||||
a_c = to_cute(mat_a)
|
||||
@@ -286,11 +266,13 @@ def run_nvfp4_grouped_gemm(
|
||||
def to_cute(t):
    """Convert `t` to a CuTe tensor via DLPack with a dynamic leading-dim layout.

    Fix: from_dlpack checks torch.cuda.current_device() against tensor device.
    Inside CUDA graph capture on non-default GPUs, current_device() may not match.
    DO NOT use torch.cuda.set_device() inside graph capture — it changes the stream.
    """
    # While the current stream is capturing, wrap the tensor so that
    # __dlpack__ is called with dl_device forced to the tensor's own device.
    capturing = (
        hasattr(torch.cuda, 'is_current_stream_capturing')
        and torch.cuda.is_current_stream_capturing()
    )
    if capturing:
        t = _DLPatchTensor(t)

    # Temporarily patch current_device to return the tensor's device index.
    # The patch is process-wide, so it must be undone even when from_dlpack
    # raises; otherwise every later current_device() call would be wrong.
    _orig_cd = torch.cuda.current_device
    if t.is_cuda and t.device.index != _orig_cd():
        device_index = t.device.index
        torch.cuda.current_device = lambda: device_index
    try:
        ct = cutlass_torch.from_dlpack(t)
    finally:
        torch.cuda.current_device = _orig_cd
    return ct.mark_layout_dynamic(leading_dim=cutlass_torch.get_leading_dim(t))
|
||||
|
||||
a_c = to_cute(mat_a)
|
||||
@@ -370,11 +352,13 @@ def warmup_fused_swiglu_compilation(num_experts, K_packed, N_packed, device,
|
||||
def to_cute(t):
    """Convert `t` to a CuTe tensor via DLPack with a dynamic leading-dim layout.

    Fix: from_dlpack checks torch.cuda.current_device() against tensor device.
    Inside CUDA graph capture on non-default GPUs, current_device() may not match.
    DO NOT use torch.cuda.set_device() inside graph capture — it changes the stream.
    """
    # While the current stream is capturing, wrap the tensor so that
    # __dlpack__ is called with dl_device forced to the tensor's own device.
    capturing = (
        hasattr(torch.cuda, 'is_current_stream_capturing')
        and torch.cuda.is_current_stream_capturing()
    )
    if capturing:
        t = _DLPatchTensor(t)

    # Temporarily patch current_device to return the tensor's device index.
    # The patch is process-wide, so it must be undone even when from_dlpack
    # raises; otherwise every later current_device() call would be wrong.
    _orig_cd = torch.cuda.current_device
    if t.is_cuda and t.device.index != _orig_cd():
        device_index = t.device.index
        torch.cuda.current_device = lambda: device_index
    try:
        ct = cutlass_torch.from_dlpack(t)
    finally:
        torch.cuda.current_device = _orig_cd
    return ct.mark_layout_dynamic(leading_dim=cutlass_torch.get_leading_dim(t))
|
||||
|
||||
a_c = to_cute(mat_a)
|
||||
@@ -471,9 +455,11 @@ def run_fused_swiglu_grouped_gemm(
|
||||
)
|
||||
|
||||
def to_cute(t):
    """Convert `t` to a CuTe tensor via DLPack with a dynamic leading-dim layout.

    While the current CUDA stream is capturing, the tensor is wrapped in
    _DLPatchTensor, and torch.cuda.current_device is temporarily replaced so
    it reports the tensor's device index for the duration of from_dlpack.
    """
    capturing = (
        hasattr(torch.cuda, 'is_current_stream_capturing')
        and torch.cuda.is_current_stream_capturing()
    )
    if capturing:
        t = _DLPatchTensor(t)

    # The replacement below is process-wide, so it is restored in `finally`
    # to avoid leaving it installed if from_dlpack raises.
    _orig_cd = torch.cuda.current_device
    if t.is_cuda and t.device.index != _orig_cd():
        device_index = t.device.index
        torch.cuda.current_device = lambda: device_index
    try:
        ct = cutlass_torch.from_dlpack(t)
    finally:
        torch.cuda.current_device = _orig_cd
    return ct.mark_layout_dynamic(leading_dim=cutlass_torch.get_leading_dim(t))
|
||||
|
||||
a_c = to_cute(mat_a)
|
||||
@@ -516,11 +502,13 @@ def run_fused_swiglu_grouped_gemm(
|
||||
def to_cute(t):
    """Convert `t` to a CuTe tensor via DLPack with a dynamic leading-dim layout.

    Fix: from_dlpack checks torch.cuda.current_device() against tensor device.
    Inside CUDA graph capture on non-default GPUs, current_device() may not match.
    DO NOT use torch.cuda.set_device() inside graph capture — it changes the stream.
    """
    # While the current stream is capturing, wrap the tensor so that
    # __dlpack__ is called with dl_device forced to the tensor's own device.
    capturing = (
        hasattr(torch.cuda, 'is_current_stream_capturing')
        and torch.cuda.is_current_stream_capturing()
    )
    if capturing:
        t = _DLPatchTensor(t)

    # Temporarily patch current_device to return the tensor's device index.
    # The patch is process-wide, so it must be undone even when from_dlpack
    # raises; otherwise every later current_device() call would be wrong.
    _orig_cd = torch.cuda.current_device
    if t.is_cuda and t.device.index != _orig_cd():
        device_index = t.device.index
        torch.cuda.current_device = lambda: device_index
    try:
        ct = cutlass_torch.from_dlpack(t)
    finally:
        torch.cuda.current_device = _orig_cd
    return ct.mark_layout_dynamic(leading_dim=cutlass_torch.get_leading_dim(t))
|
||||
|
||||
a_c = to_cute(mat_a)
|
||||
|
||||
Reference in New Issue
Block a user