diff --git a/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/cutlass_nvfp4_gemm.cu b/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/cutlass_nvfp4_gemm.cu index 0dc6303a..5efce57f 100644 --- a/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/cutlass_nvfp4_gemm.cu +++ b/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/cutlass_nvfp4_gemm.cu @@ -137,10 +137,18 @@ __global__ void remap_sf_to_cutlass_kernel( constexpr int LayoutRank = cute::rank_v; int dst_idx = 0; + // Use crd2idx to map flat logical coordinate to CUTLASS physical index. + // The layout's logical shape is ((32,4,...), (SFVecSize,4,...)) for K-major, + // but CuTe accepts flat coordinates and internally maps them. + // For SFA: logical coord is (m, k_element, l) + // For SFB: logical coord is (n, k_element, l) + // k_element = k_sf * InputSFVectorSize (convert SF group to K element) if constexpr (LayoutRank == 2) { - dst_idx = layout_sf(cute::make_coord(mn, k_sf * InputSFVectorSize)); + auto logical_coord = cute::make_coord(mn, k_sf * InputSFVectorSize); + dst_idx = cute::crd2idx(logical_coord, layout_sf); } else if constexpr (LayoutRank == 3) { - dst_idx = layout_sf(cute::make_coord(mn, k_sf * InputSFVectorSize, 0)); + auto logical_coord = cute::make_coord(mn, k_sf * InputSFVectorSize, 0); + dst_idx = cute::crd2idx(logical_coord, layout_sf); } int src_idx = mn * src_stride_mn + k_sf * src_stride_ksf;