From a8bd962452bd25dd35ee840a2c0c5dbcc28db53f Mon Sep 17 00:00:00 2001 From: biondizzle Date: Thu, 14 May 2026 15:01:47 +0000 Subject: [PATCH] =?UTF-8?q?fix:=20SF=20remap=20=E2=80=94=20iterate=20dest?= =?UTF-8?q?=20indices,=20extract=20logical=20(m,=20k=5Fsf)=20from=20nested?= =?UTF-8?q?=20coord?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The forward-map approach (src -> layout_sf(m, k)) failed because CuTe's layout operator requires coordinates matching the nested shape rank, and passing flat (int, int) to a ((32,4),K) shape triggers Mismatched Ranks. New approach: iterate over CUTLASS dest indices, use idx2crd to get the hierarchical coordinate, flatten it, then extract logical (m, k_sf) by interpreting the flattened sub-coordinates correctly: flat[0..2] = (inner_M, sub_M, tile_M) -> m = tile_M*128 + inner_M*4 + sub_M flat[3..5] = (inner_K, sub_K, tile_K) -> k_sf = tile_K*4 + sub_K (inner_K is within one SF group — same byte, so ignored for k_sf) Previous bug: get<0> and get<1> of flatten gave (inner_M, sub_M) — both M sub-indices. K information was never extracted, so only k_group=0 worked. Dest buffer is zero-initialized so padding slots (where m >= MN or k_sf >= K_sf) stay zero. --- .../cutlass_nvfp4_gemm/cutlass_nvfp4_gemm.cu | 94 ++++++++++--------- 1 file changed, 51 insertions(+), 43 deletions(-) diff --git a/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/cutlass_nvfp4_gemm.cu b/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/cutlass_nvfp4_gemm.cu index 06bbd598..000ab028 100644 --- a/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/cutlass_nvfp4_gemm.cu +++ b/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/cutlass_nvfp4_gemm.cu @@ -97,47 +97,57 @@ using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB; // Scale factor remap kernel using CuTe layout operations ///////////////////////////////////////////////////////////////////////////////////////////////// +///////////////////////////////////////////////////////////////////////////////////////////////// +// Scale factor remap kernel using CuTe layout operations +///////////////////////////////////////////////////////////////////////////////////////////////// + template __global__ void remap_sf_to_cutlass_kernel( - const cutlass::float_ue4m3_t* __restrict__ src, // (MN, K_sf) row-major - cutlass::float_ue4m3_t* __restrict__ dst, // CUTLASS interleaved layout - LayoutSF layout_sf, // CuTe layout for dst - int MN, int K_sf, // Source dimensions (in SF groups) - int SFVecSize // Elements per SF group (16 for NVFP4) + const cutlass::float_ue4m3_t* __restrict__ src, + cutlass::float_ue4m3_t* __restrict__ dst, + LayoutSF layout_sf, + int MN, int K_sf ) { - // Iterate over SOURCE indices (row-major) and write to CUTLASS destination. - // The layout maps logical (m, k_elements) -> CUTLASS linear index. - // This is the forward direction, which CuTe handles correctly. - // - // IMPORTANT: The CuTe layout uses ELEMENT-SPACE for K, not group-space. - // So k_group=3 with SFVecSize=16 maps to element k=3*16=48 in the layout. - // - // Previous approach (iterate over CUTLASS idx, reverse-map with idx2crd+flatten) - // was broken: flatten() on the nested CuTe coordinate gives atom sub-indices, - // not logical (m, k). This caused all K-groups > 0 in SFA to map to m*K_sf+0, - // losing K-group information entirely. - int src_idx = blockIdx.x * blockDim.x + threadIdx.x; - int total = MN * K_sf; - if (src_idx >= total) return; - - int m = src_idx / K_sf; - int k_sf = src_idx % K_sf; - int k_elem = k_sf * SFVecSize; // Convert to element-space for CuTe layout - - // Use the CuTe layout to find the destination index for this (m, k_elem) - // The layout has nested shape like ((32, 4, tiles_m), (16, 4, tiles_k)) but - // we can pass a flat (m, k_elem) coordinate and CuTe will project correctly. - // Using layout_sf[m, k_elem] (square brackets) avoids the rank-matching issue - // that make_coord() has with nested shapes. - int dst_idx = layout_sf(m, k_elem); + int dst_idx = blockIdx.x * blockDim.x + threadIdx.x; + int total = cute::size(layout_sf); + if (dst_idx >= total) return; + + auto coord = cute::idx2crd(dst_idx, layout_sf.shape(), layout_sf.stride()); + auto flat = cute::flatten(coord); + + constexpr int flat_rank = cute::rank_v; + + int m, k_sf_out; + + if constexpr (flat_rank == 6) { + // Full nested: (inner_M, sub_M, tile_M, inner_K, sub_K, tile_K) + int inner_m = cute::get<0>(flat); + int sub_m = cute::get<1>(flat); + int tile_m = cute::get<2>(flat); + // get<3> = inner_k (within one SF group — same byte) + int sub_k = cute::get<4>(flat); + int tile_k = cute::get<5>(flat); + + m = tile_m * 128 + inner_m * 4 + sub_m; + k_sf_out = tile_k * 4 + sub_k; + } else if constexpr (flat_rank == 4) { + int inner_m = cute::get<0>(flat); + int sub_m = cute::get<1>(flat); + int sub_k = cute::get<2>(flat); + int tile_k = cute::get<3>(flat); + + m = inner_m * 4 + sub_m; + k_sf_out = tile_k * 4 + sub_k; + } else { + m = cute::get<0>(flat); + k_sf_out = cute::get<1>(flat); + } + + if (m < MN && k_sf_out < K_sf) { + dst[dst_idx] = src[m * K_sf + k_sf_out]; + } } -///////////////////////////////////////////////////////////////////////////////////////////////// -// C API -///////////////////////////////////////////////////////////////////////////////////////////////// - -extern "C" { - int cutlass_nvfp4_gemm_run( const void* A_ptr, const void* SFA_ptr, const void* B_ptr, const void* SFB_ptr, @@ -173,13 +183,11 @@ int cutlass_nvfp4_gemm_run( cudaMemsetAsync(sfb_cutlass.get(), 0, sfb_size * sizeof(ElementSF), stream); int block = 256; - // Grid size based on SOURCE elements (M*K_sf), not CUTLASS buffer size - int sfa_src_total = M * K_sf; - int sfb_src_total = N * K_sf; - remap_sf_to_cutlass_kernel<<<(sfa_src_total + block - 1) / block, block, 0, stream>>>( - static_cast(SFA_ptr), sfa_cutlass.get(), layout_SFA, M, K_sf, InputSFVectorSize); - remap_sf_to_cutlass_kernel<<<(sfb_src_total + block - 1) / block, block, 0, stream>>>( - static_cast(SFB_ptr), sfb_cutlass.get(), layout_SFB, N, K_sf, InputSFVectorSize); + // Grid size based on DEST (CUTLASS buffer) size since kernel iterates dest + remap_sf_to_cutlass_kernel<<<(sfa_size + block - 1) / block, block, 0, stream>>>( + static_cast(SFA_ptr), sfa_cutlass.get(), layout_SFA, M, K_sf); + remap_sf_to_cutlass_kernel<<<(sfb_size + block - 1) / block, block, 0, stream>>>( + static_cast(SFB_ptr), sfb_cutlass.get(), layout_SFB, N, K_sf); typename Gemm::Arguments arguments { cutlass::gemm::GemmUniversalMode::kGemm,