From 904fc37ad8cdbcfd8a3e31fc6f2cbab036442b7b Mon Sep 17 00:00:00 2001 From: biondizzle Date: Thu, 14 May 2026 10:50:26 +0000 Subject: [PATCH] Fix: use idx2crd instead of get_coord for CuTe layout coordinate lookup --- .../cutlass_nvfp4_gemm/cutlass_nvfp4_gemm.cu | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/cutlass_nvfp4_gemm.cu b/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/cutlass_nvfp4_gemm.cu index cc8ec61c..0452827d 100644 --- a/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/cutlass_nvfp4_gemm.cu +++ b/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/cutlass_nvfp4_gemm.cu @@ -120,11 +120,15 @@ __global__ void remap_sf_to_cutlass_kernel( // The key: the layout was created by tile_to_shape(SfAtom{}, make_shape(MN, K), Step<_2,_1>{}) // So the coordinate (c0, c1) corresponds to row c0 and K-group c1 in the original tensor. - auto coord = layout_sf.get_coord(idx); + // idx2crd converts a linear index to a logical coordinate in the layout's space + auto coord = cute::idx2crd(idx, layout_sf.shape(), layout_sf.stride()); - // Extract the (m, k) pair from the coordinate tuple - int m = cute::get<0>(cute::flatten(coord)); // Row index - int k = cute::get<1>(cute::flatten(coord)); // K-group index + // The coordinate is a nested tuple. For a 2D layout it's (c0, c1) + // where c0 = row/M index and c1 = col/K-group index. + // Flatten the nested tuple and take its first two leaves as (m, k). NOTE(review): assumes those leaves are the row and K-group; a tile_to_shape layout may nest several sub-modes per dimension - confirm. + auto flat = cute::flatten(coord); + int m = cute::get<0>(flat); + int k = cute::get<1>(flat); if (m < MN && k < K_sf) { dst[idx] = src[m * K_sf + k];