diff --git a/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/cutlass_nvfp4_gemm.cu b/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/cutlass_nvfp4_gemm.cu index 860e41f7..a556cfd8 100644 --- a/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/cutlass_nvfp4_gemm.cu +++ b/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/cutlass_nvfp4_gemm.cu @@ -149,24 +149,8 @@ __global__ void remap_sf_to_cutlass_kernel( } if (m < MN && k_sf < K_sf) { - // SFA: source is (MN, K_sf) row-major → src[m * K_sf + k_sf] - // SFB: source is (K_sf, MN) row-major → src[k_sf * MN + m] dst[dst_idx] = col_major_src ? src[k_sf * MN + m] : src[m * K_sf + k_sf]; } - - // Diagnostic: print first 10 coordinate mappings (host-side, only dst_idx 0..9) - if (dst_idx < 10) { - printf("[SF-REMAP] dst=%d rank=%d", dst_idx, R); - if constexpr (R >= 1) printf(" f0=%d", (int)cute::get<0>(flat)); - if constexpr (R >= 2) printf(" f1=%d", (int)cute::get<1>(flat)); - if constexpr (R >= 3) printf(" f2=%d", (int)cute::get<2>(flat)); - if constexpr (R >= 4) printf(" f3=%d", (int)cute::get<3>(flat)); - if constexpr (R >= 5) printf(" f4=%d", (int)cute::get<4>(flat)); - if constexpr (R >= 6) printf(" f5=%d", (int)cute::get<5>(flat)); - if constexpr (R >= 7) printf(" f6=%d", (int)cute::get<6>(flat)); - if constexpr (R >= 8) printf(" f7=%d", (int)cute::get<7>(flat)); - printf(" -> m=%d k_sf=%d col_major=%d MN=%d K_sf=%d\n", m, k_sf, col_major_src, MN, K_sf); - } } ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -345,46 +329,4 @@ extern "C" int cutlass_nvfp4_gemm_run_prepacked_sfb( } // extern "C" -// Diagnostic: verify SF coordinate mapping -extern "C" int cutlass_nvfp4_gemm_diag_sf_layout( - int M, int N, int K, - int* out_cosize, int* out_size, int* out_rank, - int* out_shape, int* out_stride, // arrays of 8 ints - int* out_coord_map, // flat (dst_idx -> (m, k_sf)) pairs, max 256 entries - int max_entries -) { - using Sm1xxBlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig; - LayoutSFA layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(M, N, K, 1)); - - *out_cosize = cute::cosize(layout_SFA); - *out_size = cute::size(layout_SFA); - - auto shape = layout_SFA.shape(); - auto stride = layout_SFA.stride(); - *out_rank = cute::rank(shape); - - // Write shape and stride - for (int i = 0; i < 8; i++) { - out_shape[i] = (i < *out_rank) ? (int)cute::get<0>(cute::slice_and_offset(shape, stride)) : 0; - out_stride[i] = 0; - } - - // Write coordinate mapping - int total = std::min(*out_cosize, max_entries); - for (int i = 0; i < total; i++) { - auto coord = cute::idx2crd(i, layout_SFA.shape(), layout_SFA.stride()); - auto flat = cute::flatten(coord); - constexpr int R = cute::rank_v; - int m = 0, k_sf = 0; - if constexpr (R == 8) { - m = cute::get<0>(flat) + cute::get<1>(flat) * 32 + cute::get<2>(flat) * 128; - k_sf = cute::get<4>(flat) + cute::get<5>(flat) * 4; - } - out_coord_map[i * 2] = m; - out_coord_map[i * 2 + 1] = k_sf; - } - - return total; -} - #endif