// Excerpt of src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/cutlass_nvfp4_gemm.cu
// (two non-adjacent regions of the file: the scale-factor remap kernel, and the part of
//  cutlass_nvfp4_gemm_run that launches it).

using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA;
using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB;

/////////////////////////////////////////////////////////////////////////////////////////////////
// Scale factor remap kernel: row-major source -> CUTLASS interleaved layout
//
// Iterates over source (m, k_sf) indices and evaluates the CuTe layout at the logical
// coordinate (m, k_sf * SFVecSize) to obtain the destination offset. CuTe's layout
// operator() decomposes an integer coordinate into the nested sub-coordinates of a mode
// automatically, so the kernel does not hand-decode the (inner, sub, tile) structure and
// does not depend on how many nested modes the layout has.
//
// K is in element-space for the layout: all SFVecSize elements within a group share one
// SF byte, so addressing the first element of the group is sufficient.
//
// Launch: 1-D grid, one thread per source scale factor, i.e. at least MN * K_sf threads.
// No shared memory is used.
//
// Preconditions:
//   - dst is zero-initialized by the caller; entries that no source element maps to
//     (layout padding) are left untouched by this kernel.
//   - MN * K_sf fits in an int.
//   - NOTE(review): any layout modes beyond (M, K) are addressed at coordinate 0; this
//     assumes they are batch-like modes and that src holds a single batch — confirm.
/////////////////////////////////////////////////////////////////////////////////////////////////

template <typename LayoutSF>
__global__ void remap_sf_to_cutlass_kernel(
    const cutlass::float_ue4m3_t* __restrict__ src,  // (MN, K_sf) row-major
    cutlass::float_ue4m3_t* __restrict__ dst,        // CUTLASS interleaved layout (zero-initialized)
    LayoutSF layout_sf,                              // CuTe layout for dst
    int MN, int K_sf,                                // Source dimensions (in SF groups)
    int SFVecSize = 16                               // Elements per SF group (16 for NVFP4)
) {
  static_assert(LayoutSF::rank >= 2,
                "scale-factor layout must expose at least an M/N mode and a K mode");

  const int src_idx = blockIdx.x * blockDim.x + threadIdx.x;
  const int total = MN * K_sf;
  if (src_idx >= total) return;  // grid tail

  const int m = src_idx / K_sf;
  const int k_sf = src_idx % K_sf;
  const int k_elem = k_sf * SFVecSize;  // layout K mode is indexed in elements, not groups

  // Build (m, k_elem, 0, ...) padded to the layout's rank so that the same kernel works
  // whether or not the layout carries extra trailing modes.
  auto coord = cute::append<LayoutSF::rank>(cute::make_coord(m, k_elem), cute::Int<0>{});
  dst[layout_sf(coord)] = src[src_idx];
}

/////////////////////////////////////////////////////////////////////////////////////////////////

// ---- fragment from the body of: int cutlass_nvfp4_gemm_run( ... ) ----
// (the start and end of this function are outside the excerpt)

  int sfb_size = cute::size(layout_SFB);
  int K_sf = K / InputSFVectorSize;

  // Allocate CUTLASS-layout buffers, zero-init, and remap scales.
  // NOTE(review): the return codes of cudaMemsetAsync and the kernel launches are not
  // checked here; wire them into this function's error path once its convention is known.
  cutlass::device_memory::allocation<ElementSF> sfa_cutlass(sfa_size);
  cutlass::device_memory::allocation<ElementSF> sfb_cutlass(sfb_size);
  cudaMemsetAsync(sfa_cutlass.get(), 0, sfa_size * sizeof(ElementSF), stream);
  cudaMemsetAsync(sfb_cutlass.get(), 0, sfb_size * sizeof(ElementSF), stream);

  // One thread per source scale factor.
  int block = 256;
  int sfa_src = M * K_sf;
  int sfb_src = N * K_sf;
  remap_sf_to_cutlass_kernel<<<(sfa_src + block - 1) / block, block, 0, stream>>>(
      static_cast<const cutlass::float_ue4m3_t*>(SFA_ptr), sfa_cutlass.get(), layout_SFA,
      M, K_sf, InputSFVectorSize);
  remap_sf_to_cutlass_kernel<<<(sfb_src + block - 1) / block, block, 0, stream>>>(
      static_cast<const cutlass::float_ue4m3_t*>(SFB_ptr), sfb_cutlass.get(), layout_SFB,
      N, K_sf, InputSFVectorSize);

  typename Gemm::Arguments arguments {
    cutlass::gemm::GemmUniversalMode::kGemm,