diff --git a/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/cutlass_nvfp4_gemm.cu b/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/cutlass_nvfp4_gemm.cu index 1ce952b7..5917b1f7 100644 --- a/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/cutlass_nvfp4_gemm.cu +++ b/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/cutlass_nvfp4_gemm.cu @@ -133,38 +133,12 @@ __global__ void remap_sf_to_cutlass_kernel( int mn = tid / K_sf; int k_sf = tid % K_sf; - constexpr int LayoutRank = cute::rank_v; + // Logical K element coordinate, not compact scale-factor coordinate. + int k_elem = k_sf * 16; - // Decompose flat (mn, k_sf) into hierarchical coordinates matching the atom layout: - // Shape: ((32, 4, mn_tiles), (SFVecSize, 4, k_tiles), ...) - // First group: mn = m0 + 32*m1 + 128*mt where m0 in [0,32), m1 in [0,4), mt = mn/128 - // Second group: k_sf = k1 + 4*kt where k1 = k_sf % 4, kt = k_sf / 4 - // (k0 is the stride-0 inner SF vector index — always 0 for the first element) - int m0 = mn % 32; - int m1 = (mn / 32) % 4; - int mt = mn / 128; - int k0 = 0; // stride-0, within SF vector - int k1 = k_sf % 4; - int kt = k_sf / 4; + int dst_idx = layout_sf(cute::make_coord(mn, k_elem, 0)); - int dst_idx = 0; - if constexpr (LayoutRank == 2) { - auto hier_coord = cute::make_coord( - cute::make_coord(cute::make_coord(m0, m1), mt), - cute::make_coord(cute::make_coord(k0, k1), kt) - ); - dst_idx = layout_sf(hier_coord); - } else if constexpr (LayoutRank == 3) { - auto hier_coord = cute::make_coord( - cute::make_coord(cute::make_coord(m0, m1), mt), - cute::make_coord(cute::make_coord(k0, k1), kt), - 0 - ); - dst_idx = layout_sf(hier_coord); - } - - int src_idx = mn * src_stride_mn + k_sf * src_stride_ksf; - dst[dst_idx] = src[src_idx]; + dst[dst_idx] = src[mn * src_stride_mn + k_sf * src_stride_ksf]; } /////////////////////////////////////////////////////////////////////////////////////////////////