diff --git a/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/cutlass_nvfp4_gemm.cu b/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/cutlass_nvfp4_gemm.cu index 0452827d..776f153e 100644 --- a/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/cutlass_nvfp4_gemm.cu +++ b/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/cutlass_nvfp4_gemm.cu @@ -104,36 +104,27 @@ __global__ void remap_sf_to_cutlass_kernel( LayoutSF layout_sf, // CuTe layout for dst int MN, int K_sf // Source dimensions ) { - int idx = blockIdx.x * blockDim.x + threadIdx.x; - int total = cute::size(layout_sf); - if (idx >= total) return; - - // The CUTLASS layout maps linear index -> (m, k) coordinate pair - // We need to find which (m, k) this linear index corresponds to - // and then read from our simple row-major source. + // Iterate over SOURCE indices (row-major) and write to CUTLASS destination. + // The layout maps logical (m, k) -> CUTLASS linear index. + // This is the forward direction, which CuTe handles correctly. // - // CuTe layouts support crd(idx) which gives the coordinate for an index. - // The coordinate is in the logical space of the layout. - // For SFA: the layout maps to (M, K) where K is in SF groups - // For SFB: the layout maps to (N, K) where K is in SF groups - // - // The key: the layout was created by tile_to_shape(SfAtom{}, make_shape(MN, K), Step<_2,_1>{}) - // So the coordinate (c0, c1) corresponds to row c0 and K-group c1 in the original tensor. + // Previous approach (iterate over CUTLASS idx, reverse-map with idx2crd+flatten) + // was broken: flatten() on the nested CuTe coordinate gives atom sub-indices, + // not logical (m, k). This caused K-group > 0 to always map to m*K_sf+0, + // losing all K-group information in the SFA. + int src_idx = blockIdx.x * blockDim.x + threadIdx.x; + int total = MN * K_sf; + if (src_idx >= total) return; - // idx2crd converts a linear index to a logical coordinate in the layout's space - auto coord = cute::idx2crd(idx, layout_sf.shape(), layout_sf.stride()); + int m = src_idx / K_sf; + int k = src_idx % K_sf; - // The coordinate is a nested tuple. For a 2D layout it's (c0, c1) - // where c0 = row/M index and c1 = col/K-group index. - // Flatten the nested tuple to extract the two logical coordinates. - auto flat = cute::flatten(coord); - int m = cute::get<0>(flat); - int k = cute::get<1>(flat); + // Use the CuTe layout to find the destination index for this (m, k) + // layout_sf(m, k) returns the linear index in CUTLASS's expected layout + auto dst_idx = layout_sf(cute::make_coord(m, k)); - if (m < MN && k < K_sf) { - dst[idx] = src[m * K_sf + k]; - } else { - dst[idx] = cutlass::float_ue4m3_t(0); + if (dst_idx < cute::size(layout_sf)) { + dst[dst_idx] = src[src_idx]; } }