diff --git a/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/cutlass_nvfp4_gemm.cu b/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/cutlass_nvfp4_gemm.cu index fbac8d4a..5ae26fe5 100644 --- a/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/cutlass_nvfp4_gemm.cu +++ b/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/cutlass_nvfp4_gemm.cu @@ -124,39 +124,40 @@ __global__ void remap_sf_to_cutlass_kernel( int MN, int K_sf, // Source dimensions (in SF groups) bool col_major_src = false // true if source is (K_sf, MN) row-major ) { - int dst_idx = blockIdx.x * blockDim.x + threadIdx.x; - int total = cute::cosize(layout_sf); - if (dst_idx >= total) return; + // Forward-mapping approach: iterate over source indices (mn, k_sf), + // compute the CUTLASS dst index via layout_sf forward mapping. + // k_sf * InputSFVectorSize converts from SF-group index to logical K element + // coordinate, which is what the layout expects. + int tid = blockIdx.x * blockDim.x + threadIdx.x; + int total = MN * K_sf; + if (tid >= total) return; - auto coord = cute::idx2crd(dst_idx, layout_sf.shape(), layout_sf.stride()); - auto flat = cute::flatten(coord); + int mn, k_sf_val, src_idx; - constexpr int R = cute::rank_v; - - int mn = 0; - int k_sf = 0; - - if constexpr (R >= 6) { - // K-major SF atom: Shape, Shape> - // Stride, Stride<_0,_1>> - // SFA tiled as make_shape(M,K,...), SFB as make_shape(N,K,...) - // First flattened group is M/N: ((32,4), tile_mn) - mn = - int(cute::get<0>(flat)) + - 32 * int(cute::get<1>(flat)) + - 128 * int(cute::get<2>(flat)); - - // Second flattened group is K: (SFVecSize, 4, tile_k) - // get<3> is the stride-0 k-within-SF-vector coordinate — ignore it. - k_sf = - int(cute::get<4>(flat)) + - 4 * int(cute::get<5>(flat)); + if (col_major_src) { + // source is row-major (K_sf, MN), e.g. SFB stored as (K_sf, N) + k_sf_val = tid / MN; + mn = tid % MN; + src_idx = tid; + } else { + // source is row-major (MN, K_sf), e.g. SFA stored as (M, K_sf) + mn = tid / K_sf; + k_sf_val = tid % K_sf; + src_idx = tid; } - if (mn < MN && k_sf < K_sf) { - dst[dst_idx] = col_major_src - ? src[k_sf * MN + mn] // SFB source: (K_sf, N) - : src[mn * K_sf + k_sf]; // SFA source: (M, K_sf) + // Use layout forward mapping: source (mn, k_sf*16) -> dst_idx + constexpr int LayoutRank = cute::rank_v; + int dst_idx = 0; + + if constexpr (LayoutRank == 2) { + dst_idx = layout_sf(cute::make_coord(mn, k_sf_val * InputSFVectorSize)); + } else if constexpr (LayoutRank == 3) { + dst_idx = layout_sf(cute::make_coord(mn, k_sf_val * InputSFVectorSize, 0)); + } + + if (dst_idx >= 0 && dst_idx < cute::cosize(layout_sf)) { + dst[dst_idx] = src[src_idx]; } } @@ -197,9 +198,11 @@ int cutlass_nvfp4_gemm_run( cudaMemsetAsync(sfb_cutlass.get(), 0, sfb_size * sizeof(ElementSF), stream); int block = 256; - remap_sf_to_cutlass_kernel<<<(sfa_size + block - 1) / block, block, 0, stream>>>( + int sfa_src_total = M * K_sf; + int sfb_src_total = N * K_sf; + remap_sf_to_cutlass_kernel<<<(sfa_src_total + block - 1) / block, block, 0, stream>>>( static_cast(SFA_ptr), sfa_cutlass.get(), layout_SFA, M, K_sf, false); - remap_sf_to_cutlass_kernel<<<(sfb_size + block - 1) / block, block, 0, stream>>>( + remap_sf_to_cutlass_kernel<<<(sfb_src_total + block - 1) / block, block, 0, stream>>>( static_cast(SFB_ptr), sfb_cutlass.get(), layout_SFB, N, K_sf, true); typename Gemm::Arguments arguments { @@ -260,7 +263,8 @@ extern "C" int cutlass_nvfp4_prepack_sfb_run( cudaMemsetAsync(static_cast(SFB_cutlass_ptr), 0, sfb_size * sizeof(ElementSF), stream); int block = 256; - remap_sf_to_cutlass_kernel<<<(sfb_size + block - 1) / block, block, 0, stream>>>( + int sfb_src_total = N * K_sf; + remap_sf_to_cutlass_kernel<<<(sfb_src_total + block - 1) / block, block, 0, stream>>>( static_cast(SFB_ptr), static_cast(SFB_cutlass_ptr), layout_SFB, @@ -303,7 +307,8 @@ extern "C" int cutlass_nvfp4_gemm_run_prepacked_sfb( cudaMemsetAsync(sfa_cutlass.get(), 0, sfa_size * sizeof(ElementSF), stream); int block = 256; - remap_sf_to_cutlass_kernel<<<(sfa_size + block - 1) / block, block, 0, stream>>>( + int sfa_src_total = M * K_sf; + remap_sf_to_cutlass_kernel<<<(sfa_src_total + block - 1) / block, block, 0, stream>>>( static_cast(SFA_ptr), sfa_cutlass.get(), layout_SFA, M, K_sf, false); typename Gemm::Arguments arguments {