diff --git a/dsv4/ops/gemm_runner.py b/dsv4/ops/gemm_runner.py index a7d24b68..5071e51f 100644 --- a/dsv4/ops/gemm_runner.py +++ b/dsv4/ops/gemm_runner.py @@ -99,6 +99,10 @@ def warmup_compilation(num_experts, K_packed, N_packed, device, ) def to_cute(t): + # Ensure current CUDA device matches tensor device (required for from_dlpack + # inside CUDA graph capture, where torch.cuda.current_device() may not match) + if t.is_cuda and t.device.index != torch.cuda.current_device(): + torch.cuda.set_device(t.device.index) ct = cutlass_torch.from_dlpack(t) return ct.mark_layout_dynamic(leading_dim=cutlass_torch.get_leading_dim(t)) @@ -203,6 +207,10 @@ def run_nvfp4_grouped_gemm( ) def to_cute(t): + # Ensure current CUDA device matches tensor device (required for from_dlpack + # inside CUDA graph capture, where torch.cuda.current_device() may not match) + if t.is_cuda and t.device.index != torch.cuda.current_device(): + torch.cuda.set_device(t.device.index) ct = cutlass_torch.from_dlpack(t) return ct.mark_layout_dynamic(leading_dim=cutlass_torch.get_leading_dim(t)) @@ -250,6 +258,10 @@ def run_nvfp4_grouped_gemm( # This is cheap (metadata only, no GPU work) and avoids stale # references to tensors from previous calls that may have been freed. def to_cute(t): + # Ensure current CUDA device matches tensor device (required for from_dlpack + # inside CUDA graph capture, where torch.cuda.current_device() may not match) + if t.is_cuda and t.device.index != torch.cuda.current_device(): + torch.cuda.set_device(t.device.index) ct = cutlass_torch.from_dlpack(t) return ct.mark_layout_dynamic(leading_dim=cutlass_torch.get_leading_dim(t)) @@ -328,6 +340,10 @@ def warmup_fused_swiglu_compilation(num_experts, K_packed, N_packed, device, ) def to_cute(t): + # Ensure current CUDA device matches tensor device (required for from_dlpack + # inside CUDA graph capture, where torch.cuda.current_device() may not match) + if t.is_cuda and t.device.index != torch.cuda.current_device(): + torch.cuda.set_device(t.device.index) ct = cutlass_torch.from_dlpack(t) return ct.mark_layout_dynamic(leading_dim=cutlass_torch.get_leading_dim(t)) @@ -425,6 +441,10 @@ def run_fused_swiglu_grouped_gemm( ) def to_cute(t): + # Ensure current CUDA device matches tensor device (required for from_dlpack + # inside CUDA graph capture, where torch.cuda.current_device() may not match) + if t.is_cuda and t.device.index != torch.cuda.current_device(): + torch.cuda.set_device(t.device.index) ct = cutlass_torch.from_dlpack(t) return ct.mark_layout_dynamic(leading_dim=cutlass_torch.get_leading_dim(t)) @@ -466,6 +486,10 @@ def run_fused_swiglu_grouped_gemm( workspace = entry['workspace'] def to_cute(t): + # Ensure current CUDA device matches tensor device (required for from_dlpack + # inside CUDA graph capture, where torch.cuda.current_device() may not match) + if t.is_cuda and t.device.index != torch.cuda.current_device(): + torch.cuda.set_device(t.device.index) ct = cutlass_torch.from_dlpack(t) return ct.mark_layout_dynamic(leading_dim=cutlass_torch.get_leading_dim(t))