diff --git a/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/cutlass_nvfp4_gemm.cu b/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/cutlass_nvfp4_gemm.cu index 6c0178bf..0cdc82c6 100644 --- a/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/cutlass_nvfp4_gemm.cu +++ b/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/cutlass_nvfp4_gemm.cu @@ -98,59 +98,40 @@ using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB; ///////////////////////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////////////////////// -// Scale factor remap kernel using CuTe layout operations +// Scale factor remap kernel: row-major source -> CUTLASS interleaved layout ///////////////////////////////////////////////////////////////////////////////////////////////// template __global__ void remap_sf_to_cutlass_kernel( - const cutlass::float_ue4m3_t* __restrict__ src, - cutlass::float_ue4m3_t* __restrict__ dst, - LayoutSF layout_sf, - int MN, int K_sf + const cutlass::float_ue4m3_t* __restrict__ src, // (MN, K_sf) row-major + cutlass::float_ue4m3_t* __restrict__ dst, // CUTLASS interleaved layout (zero-initialized) + LayoutSF layout_sf, // CuTe layout for dst + int MN, int K_sf, // Source dimensions (in SF groups) + int SFVecSize // Elements per SF group (16 for NVFP4) ) { - int dst_idx = blockIdx.x * blockDim.x + threadIdx.x; - int total = cute::size(layout_sf); - if (dst_idx >= total) return; + // SOURCE-ITERATING: for each (m, k_group) in the row-major source, + // compute the CUTLASS destination offset using layout_sf(m, k_elem). + // + // CuTe's layout operator() accepts "natural" coordinates whose rank matches + // the top-level mode count (rank-2 here: M and K). It decomposes flat ints + // into nested sub-coordinates automatically. No idx2crd, no flatten, no rank inspection. + // + // The K coordinate is in ELEMENT-SPACE (not group-space): + // k_group=3 with SFVecSize=16 -> k_elem=48 in the layout. + // All 16 elements within a group share the same SF byte, so we iterate + // over groups (not elements) for efficiency — one write per SF value. + int src_idx = blockIdx.x * blockDim.x + threadIdx.x; + int total = MN * K_sf; + if (src_idx >= total) return; - auto coord = cute::idx2crd(dst_idx, layout_sf.shape(), layout_sf.stride()); - auto flat = cute::flatten(coord); - - constexpr int flat_rank = cute::rank_v; - - // Debug: print flat_rank and first few coordinates (only for idx 0) - if (dst_idx == 0) { - printf("[remap] flat_rank=%d, MN=%d, K_sf=%d, total=%d\n", flat_rank, MN, K_sf, total); - printf("[remap] layout shape: "); - cute::print(layout_sf.shape()); - printf("\n"); - } - - int m, k_sf_out; - - if constexpr (flat_rank == 6) { - // Full nested: (inner_M, sub_M, tile_M, inner_K, sub_K, tile_K) - int inner_m = cute::get<0>(flat); - int sub_m = cute::get<1>(flat); - int tile_m = cute::get<2>(flat); - // get<3> = inner_k (within one SF group — same byte) - int sub_k = cute::get<4>(flat); - int tile_k = cute::get<5>(flat); - - m = tile_m * 128 + inner_m * 4 + sub_m; - k_sf_out = tile_k * 4 + sub_k; - } else if constexpr (flat_rank == 4) { - int inner_m = cute::get<0>(flat); - int sub_m = cute::get<1>(flat); - int sub_k = cute::get<2>(flat); - int tile_k = cute::get<3>(flat); - - m = inner_m * 4 + sub_m; - k_sf_out = tile_k * 4 + sub_k; - } else { - m = cute::get<0>(flat); - k_sf_out = cute::get<1>(flat); - } + int m = src_idx / K_sf; + int k_group = src_idx % K_sf; + int k_elem = k_group * SFVecSize; + // layout_sf(m, k_elem) -> CUTLASS linear index + // CuTe handles the nested shape decomposition internally. + dst[layout_sf(m, k_elem)] = src[src_idx]; +} if (m < MN && k_sf_out < K_sf) { dst[dst_idx] = src[m * K_sf + k_sf_out]; } @@ -197,11 +178,13 @@ int cutlass_nvfp4_gemm_run( cudaMemsetAsync(sfb_cutlass.get(), 0, sfb_size * sizeof(ElementSF), stream); int block = 256; - // Grid size based on DEST (CUTLASS buffer) size since kernel iterates dest - remap_sf_to_cutlass_kernel<<<(sfa_size + block - 1) / block, block, 0, stream>>>( - static_cast(SFA_ptr), sfa_cutlass.get(), layout_SFA, M, K_sf); - remap_sf_to_cutlass_kernel<<<(sfb_size + block - 1) / block, block, 0, stream>>>( - static_cast(SFB_ptr), sfb_cutlass.get(), layout_SFB, N, K_sf); + // Grid size based on SOURCE elements (MN * K_sf) since kernel iterates source + int sfa_src = M * K_sf; + int sfb_src = N * K_sf; + remap_sf_to_cutlass_kernel<<<(sfa_src + block - 1) / block, block, 0, stream>>>( + static_cast(SFA_ptr), sfa_cutlass.get(), layout_SFA, M, K_sf, InputSFVectorSize); + remap_sf_to_cutlass_kernel<<<(sfb_src + block - 1) / block, block, 0, stream>>>( + static_cast(SFB_ptr), sfb_cutlass.get(), layout_SFB, N, K_sf, InputSFVectorSize); typename Gemm::Arguments arguments { cutlass::gemm::GemmUniversalMode::kGemm,