diff --git a/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/cutlass_nvfp4_gemm.cu b/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/cutlass_nvfp4_gemm.cu index a556cfd8..4e38a3ec 100644 --- a/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/cutlass_nvfp4_gemm.cu +++ b/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/cutlass_nvfp4_gemm.cu @@ -136,13 +136,18 @@ __global__ void remap_sf_to_cutlass_kernel( int m = 0, k_sf = 0; if constexpr (R == 6) { - // K-major SfAtom: ((32, 4, k_tiles), (SFVecSize, 4, m_tiles)) - // Flattened: (inner_k, sub_k, tile_k, inner_m, sub_m, tile_m) - k_sf = cute::get<0>(flat) + cute::get<1>(flat) * 32 + cute::get<2>(flat) * 128; + // K-major SfAtom tiled with Step<(2,1,3)>: + // Shape: ((32, 4, K_tiles), (SFVecSize, 4, M_tiles)) + // First group covers K dimension, second covers M dimension + // get<0..2> = K group, get<3..5> = M group + // K element index = get<0> + get<1>*32 + get<2>*128 + // k_sf = K_element_index / SFVecSize + // M element index = get<3> + get<4>*SFVecSize + get<5>*(SFVecSize*4) + k_sf = (cute::get<0>(flat) + cute::get<1>(flat) * 32 + cute::get<2>(flat) * 128) / InputSFVectorSize; m = cute::get<3>(flat) + cute::get<4>(flat) * InputSFVectorSize + cute::get<5>(flat) * (InputSFVectorSize * 4); } else if constexpr (R == 8) { - // With batch dimension: ((32, 4, k_tiles), (SFVecSize, 4, m_tiles), (1,)) - k_sf = cute::get<0>(flat) + cute::get<1>(flat) * 32 + cute::get<2>(flat) * 128; + // With batch/L dimension: ((32, 4, K_tiles), (SFVecSize, 4, M_tiles), (1, 1, L)) + k_sf = (cute::get<0>(flat) + cute::get<1>(flat) * 32 + cute::get<2>(flat) * 128) / InputSFVectorSize; m = cute::get<3>(flat) + cute::get<4>(flat) * InputSFVectorSize + cute::get<5>(flat) * (InputSFVectorSize * 4); } else { m = 0; k_sf = 0;