diff --git a/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/cutlass_nvfp4_gemm.cu b/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/cutlass_nvfp4_gemm.cu index 00a92d43..4053279e 100644 --- a/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/cutlass_nvfp4_gemm.cu +++ b/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/cutlass_nvfp4_gemm.cu @@ -131,17 +131,20 @@ __global__ void remap_sf_to_cutlass_kernel( printf("[remap] flat_rank=%d MN=%d K_sf=%d total=%d\n", R, MN, K_sf, total); } - // Debug: print first few coordinates - if (dst_idx < 3) { - printf("[remap] idx=%d flat_rank=%d vals=", dst_idx, R); + // Debug: print specific indices to understand the coordinate decomposition + // Print idx 0, 1, 4, 16, 128, 512, 512*448, 128*448 + if (dst_idx == 0 || dst_idx == 1 || dst_idx == 4 || dst_idx == 16 || + dst_idx == 128 || dst_idx == 512 || dst_idx == 512*448 || + dst_idx == 128*448 || dst_idx == 4 || dst_idx == 3) { + printf("[remap] idx=%d", dst_idx); if constexpr (R >= 8) { - printf("%d,%d,%d,%d,%d,%d,%d,%d", + printf(" f=%d,%d,%d,%d,%d,%d,%d,%d", int(cute::get<0>(flat)), int(cute::get<1>(flat)), int(cute::get<2>(flat)), int(cute::get<3>(flat)), int(cute::get<4>(flat)), int(cute::get<5>(flat)), int(cute::get<6>(flat)), int(cute::get<7>(flat))); } - printf("\n"); + printf(" m=%d k=%d\n", m, k_sf); } int m = 0, k_sf = 0;