diff --git a/dsv4/ops/gemm_runner.py b/dsv4/ops/gemm_runner.py
index 0c9a67e5..e0711628 100644
--- a/dsv4/ops/gemm_runner.py
+++ b/dsv4/ops/gemm_runner.py
@@ -27,30 +27,6 @@ from dsv4.ops.layouts import (
 )
 
 
-class _DLPatchTensor:
-    """Wrapper that patches __dlpack__ to force dl_device during CUDA graph capture.
-
-    from_dlpack checks torch.cuda.current_device() against the tensor's device.
-    Inside CUDA graph capture on non-default GPUs, current_device() may not match.
-    This wrapper forces dl_device in __dlpack__, bypassing the device check.
-
-    DO NOT use torch.cuda.set_device() inside graph capture — it changes the stream
-    and causes 'Capture must end on the same stream it began on'.
-    """
-    __slots__ = ('_t',)
-    def __init__(self, t):
-        self._t = t
-    def __dlpack__(self, *args, **kwargs):
-        kwargs['dl_device'] = (1, self._t.device.index)  # kDLCUDA=1
-        try:
-            return self._t.__dlpack__(*args, **kwargs)
-        except TypeError:
-            return self._t.__dlpack__(*args, **{k: v for k, v in kwargs.items() if k != 'dl_device'})
-    def __dlpack_device__(self):
-        return (1, self._t.device.index)
-    def __getattr__(self, name):
-        return getattr(self._t, name)
-
 
 # Cache compiled kernels + pre-allocated workspace by cache_key
 # Each entry: {'compiled': callable, 'workspace': Tensor, 'workspace_size': int}
@@ -127,11 +103,15 @@ def warmup_compilation(num_experts, K_packed, N_packed, device,
     def to_cute(t):
         # Fix: from_dlpack checks torch.cuda.current_device() against tensor device.
         # Inside CUDA graph capture on non-default GPUs, current_device() may not match.
-        # We wrap the tensor to force dl_device in __dlpack__, bypassing the check.
-        # DO NOT use torch.cuda.set_device() inside graph capture — it changes the stream.
-        if hasattr(torch.cuda, 'is_current_stream_capturing') and torch.cuda.is_current_stream_capturing():
-            t = _DLPatchTensor(t)
-        ct = cutlass_torch.from_dlpack(t)
+        # We temporarily patch current_device to return the tensor's device index.
+        # The patch swaps a process-global attribute, so `finally` restores it even on error.
+        _orig_cd = torch.cuda.current_device
+        if t.is_cuda and t.device.index != _orig_cd():
+            torch.cuda.current_device = lambda: t.device.index
+        try:
+            ct = cutlass_torch.from_dlpack(t)
+        finally:
+            torch.cuda.current_device = _orig_cd
         return ct.mark_layout_dynamic(leading_dim=cutlass_torch.get_leading_dim(t))
 
     a_c = to_cute(mat_a)
@@ -235,9 +215,14 @@ def run_nvfp4_grouped_gemm(
     )
 
     def to_cute(t):
-        if hasattr(torch.cuda, 'is_current_stream_capturing') and torch.cuda.is_current_stream_capturing():
-            t = _DLPatchTensor(t)
-        ct = cutlass_torch.from_dlpack(t)
+        # Temporarily patch current_device so from_dlpack's device check passes; always restored.
+        _orig_cd = torch.cuda.current_device
+        if t.is_cuda and t.device.index != _orig_cd():
+            torch.cuda.current_device = lambda: t.device.index
+        try:
+            ct = cutlass_torch.from_dlpack(t)
+        finally:
+            torch.cuda.current_device = _orig_cd
         return ct.mark_layout_dynamic(leading_dim=cutlass_torch.get_leading_dim(t))
 
     a_c = to_cute(mat_a)
@@ -286,11 +271,15 @@ def run_nvfp4_grouped_gemm(
     def to_cute(t):
         # Fix: from_dlpack checks torch.cuda.current_device() against tensor device.
         # Inside CUDA graph capture on non-default GPUs, current_device() may not match.
-        # We wrap the tensor to force dl_device in __dlpack__, bypassing the check.
-        # DO NOT use torch.cuda.set_device() inside graph capture — it changes the stream.
-        if hasattr(torch.cuda, 'is_current_stream_capturing') and torch.cuda.is_current_stream_capturing():
-            t = _DLPatchTensor(t)
-        ct = cutlass_torch.from_dlpack(t)
+        # We temporarily patch current_device to return the tensor's device index.
+        # The patch swaps a process-global attribute, so `finally` restores it even on error.
+        _orig_cd = torch.cuda.current_device
+        if t.is_cuda and t.device.index != _orig_cd():
+            torch.cuda.current_device = lambda: t.device.index
+        try:
+            ct = cutlass_torch.from_dlpack(t)
+        finally:
+            torch.cuda.current_device = _orig_cd
         return ct.mark_layout_dynamic(leading_dim=cutlass_torch.get_leading_dim(t))
 
     a_c = to_cute(mat_a)
@@ -370,11 +359,15 @@ def warmup_fused_swiglu_compilation(num_experts, K_packed, N_packed, device,
     def to_cute(t):
         # Fix: from_dlpack checks torch.cuda.current_device() against tensor device.
         # Inside CUDA graph capture on non-default GPUs, current_device() may not match.
-        # We wrap the tensor to force dl_device in __dlpack__, bypassing the check.
-        # DO NOT use torch.cuda.set_device() inside graph capture — it changes the stream.
-        if hasattr(torch.cuda, 'is_current_stream_capturing') and torch.cuda.is_current_stream_capturing():
-            t = _DLPatchTensor(t)
-        ct = cutlass_torch.from_dlpack(t)
+        # We temporarily patch current_device to return the tensor's device index.
+        # The patch swaps a process-global attribute, so `finally` restores it even on error.
+        _orig_cd = torch.cuda.current_device
+        if t.is_cuda and t.device.index != _orig_cd():
+            torch.cuda.current_device = lambda: t.device.index
+        try:
+            ct = cutlass_torch.from_dlpack(t)
+        finally:
+            torch.cuda.current_device = _orig_cd
         return ct.mark_layout_dynamic(leading_dim=cutlass_torch.get_leading_dim(t))
 
     a_c = to_cute(mat_a)
@@ -471,9 +464,14 @@ def run_fused_swiglu_grouped_gemm(
     )
 
     def to_cute(t):
-        if hasattr(torch.cuda, 'is_current_stream_capturing') and torch.cuda.is_current_stream_capturing():
-            t = _DLPatchTensor(t)
-        ct = cutlass_torch.from_dlpack(t)
+        # Temporarily patch current_device so from_dlpack's device check passes; always restored.
+        _orig_cd = torch.cuda.current_device
+        if t.is_cuda and t.device.index != _orig_cd():
+            torch.cuda.current_device = lambda: t.device.index
+        try:
+            ct = cutlass_torch.from_dlpack(t)
+        finally:
+            torch.cuda.current_device = _orig_cd
         return ct.mark_layout_dynamic(leading_dim=cutlass_torch.get_leading_dim(t))
 
     a_c = to_cute(mat_a)
@@ -516,11 +514,15 @@ def run_fused_swiglu_grouped_gemm(
     def to_cute(t):
         # Fix: from_dlpack checks torch.cuda.current_device() against tensor device.
         # Inside CUDA graph capture on non-default GPUs, current_device() may not match.
-        # We wrap the tensor to force dl_device in __dlpack__, bypassing the check.
-        # DO NOT use torch.cuda.set_device() inside graph capture — it changes the stream.
-        if hasattr(torch.cuda, 'is_current_stream_capturing') and torch.cuda.is_current_stream_capturing():
-            t = _DLPatchTensor(t)
-        ct = cutlass_torch.from_dlpack(t)
+        # We temporarily patch current_device to return the tensor's device index.
+        # The patch swaps a process-global attribute, so `finally` restores it even on error.
+        _orig_cd = torch.cuda.current_device
+        if t.is_cuda and t.device.index != _orig_cd():
+            torch.cuda.current_device = lambda: t.device.index
+        try:
+            ct = cutlass_torch.from_dlpack(t)
+        finally:
+            torch.cuda.current_device = _orig_cd
         return ct.mark_layout_dynamic(leading_dim=cutlass_torch.get_leading_dim(t))
 
     a_c = to_cute(mat_a)