diff --git a/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/cutlass_nvfp4_gemm.cu b/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/cutlass_nvfp4_gemm.cu index 93806034..00a92d43 100644 --- a/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/cutlass_nvfp4_gemm.cu +++ b/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/cutlass_nvfp4_gemm.cu @@ -131,46 +131,63 @@ __global__ void remap_sf_to_cutlass_kernel( printf("[remap] flat_rank=%d MN=%d K_sf=%d total=%d\n", R, MN, K_sf, total); } + // Debug: print first few coordinates + if (dst_idx < 3) { + printf("[remap] idx=%d flat_rank=%d vals=", dst_idx, R); + if constexpr (R >= 8) { + printf("%d,%d,%d,%d,%d,%d,%d,%d", + int(cute::get<0>(flat)), int(cute::get<1>(flat)), + int(cute::get<2>(flat)), int(cute::get<3>(flat)), + int(cute::get<4>(flat)), int(cute::get<5>(flat)), + int(cute::get<6>(flat)), int(cute::get<7>(flat))); + } + printf("\n"); + } + int m = 0, k_sf = 0; - if constexpr (R == 6) { - // Full 6: (inner_m, sub_m, tile_m, inner_k, sub_k, tile_k) - int inner_m = cute::get<0>(flat); - int sub_m = cute::get<1>(flat); - int tile_m = cute::get<2>(flat); - // get<3> = inner_k (within SF group, ignored for k_sf) - int sub_k = cute::get<4>(flat); - int tile_k = cute::get<5>(flat); - m = tile_m * 128 + inner_m * 4 + sub_m; - k_sf = tile_k * 4 + sub_k; - } else if constexpr (R == 5) { - // 5: maybe (inner_m, sub_m, tile_m, sub_k, tile_k) or similar - int inner_m = cute::get<0>(flat); - int sub_m = cute::get<1>(flat); - int tile_m = cute::get<2>(flat); - int sub_k = cute::get<3>(flat); - int tile_k = cute::get<4>(flat); - m = tile_m * 128 + inner_m * 4 + sub_m; - k_sf = tile_k * 4 + sub_k; - } else if constexpr (R == 4) { - // 4: (inner_m, sub_m, sub_k, tile_k) — small M fits in 1 tile - int inner_m = cute::get<0>(flat); - int sub_m = cute::get<1>(flat); - int sub_k = cute::get<2>(flat); - int tile_k = cute::get<3>(flat); - m = inner_m * 4 + sub_m; - k_sf = tile_k * 4 + sub_k; - } else if constexpr (R == 3) { - // 3: maybe (inner_m, sub_m, k_combined) - int inner_m = cute::get<0>(flat); - int sub_m = cute::get<1>(flat); - int k_comb = cute::get<2>(flat); - m = inner_m * 4 + sub_m; - k_sf = k_comb; - } else if constexpr (R == 2) { - // 2: flat (m, k) — no nesting - m = cute::get<0>(flat); - k_sf = cute::get<1>(flat); + if constexpr (R == 8) { + // 8 flattened coordinates: 4 for M, 4 for K + // M: (inner_32, inner_4, tile_interleave, tile_m) + // K: (inner_16, inner_4, tile_interleave_k, tile_k) + // + // SfAtom shape (32,4) with stride (16,4): + // inner_32 * 16 + inner_4 * 4 = local offset within atom + // m = inner_32 * 4 + inner_4 + (tile_interleave + tile_m * 2) * 128 + // (step=2 means M tiles interleave, so tile_interleave is 0 or 1) + // k_sf = inner_16 is within one SF group (0..15), ignored + // inner_4 is the sub-K (0..3) + // k_sf = inner_4 + tile_interleave_k * 4 + tile_k * 4 + // Actually inner_16 covers 16 elements = 1 SF group, so it doesn't + // contribute to k_sf. k_sf = sub_k + (tile_k * 4) + // But wait, with the tile_interleave_k, it might be: + // k_sf = inner_4 + tile_interleave_k * 4 + tile_k * 8 + // + // Let me just compute m and k_sf empirically from the flat values: + int f0 = cute::get<0>(flat); + int f1 = cute::get<1>(flat); + int f2 = cute::get<2>(flat); + int f3 = cute::get<3>(flat); + int f4 = cute::get<4>(flat); + int f5 = cute::get<5>(flat); + int f6 = cute::get<6>(flat); + int f7 = cute::get<7>(flat); + + // M = f0..f3, K = f4..f7 + // f0 = inner_32 (0..31), f1 = inner_4 (0..3) + // These two give local_m = f0 * 4 + f1 (0..127) + // f2, f3 are M tiling: with Step<2>, f2 is interleave (0..1), f3 is tile (0..) + // m = (f3 * 2 + f2) * 128 + f0 * 4 + f1 + m = (f3 * 2 + f2) * 128 + f0 * 4 + f1; + + // f4 = inner_16 (0..15, within one SF group, doesn't contribute to k_sf) + // f5 = inner_4_k (0..3, K sub-index within atom) + // f6, f7 are K tiling + // k_sf = f5 + f6 * 4 + f7 * 8 (guessing the tiling pattern) + // Actually with Step<1> on K, the tiling is simpler: + // k_sf = f5 + (f6 + f7 * n_k_interleave) * 4 + // Let me try: k_sf = f5 + f6 * 4 (assuming f7 is outer tile) + k_sf = f5 + f6 * 4; // This may need adjustment based on printf output } else { // Fallback: index 0 and 1 m = 0; k_sf = 0;