From 5ff1b9e4011313c6bfd802b936a404bd2d797119 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Fri, 15 May 2026 22:11:14 +0000 Subject: [PATCH] fix: use hierarchical coordinates for layout_sf forward mapping Flat make_coord(mn, k*16) doesn't decompose into the nested atom shape. Must manually decompose: mn -> (m0, m1, mt) where m0=mn%32, m1=(mn/32)%4, mt=mn/128 k_sf -> (k0, k1, kt) where k0=0 (stride-0), k1=k_sf%4, kt=k_sf/4 --- .../cutlass_nvfp4_gemm/cutlass_nvfp4_gemm.cu | 31 ++++++++++++++----- 1 file changed, 24 insertions(+), 7 deletions(-) diff --git a/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/cutlass_nvfp4_gemm.cu b/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/cutlass_nvfp4_gemm.cu index 855810cd..566444b3 100644 --- a/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/cutlass_nvfp4_gemm.cu +++ b/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/cutlass_nvfp4_gemm.cu @@ -137,15 +137,32 @@ __global__ void remap_sf_to_cutlass_kernel( int dst_idx = 0; // Use crd2idx to map flat logical coordinate to CUTLASS physical index. - // The layout's logical shape is ((32,4,...), (SFVecSize,4,...)) for K-major, - // but CuTe accepts flat coordinates and internally maps them. - // For SFA: logical coord is (m, k_element, l) - // For SFB: logical coord is (n, k_element, l) - // k_element = k_sf * InputSFVectorSize (convert SF group to K element) + // Decompose flat (mn, k_sf) into hierarchical coordinates matching the atom layout: + // Shape: ((32, 4, mn_tiles), (SFVecSize, 4, k_tiles), ...) + // First group: mn = m0 + 32*m1 + 128*mt where m0 in [0,32), m1 in [0,4), mt = mn/128 + // Second group: k_sf = k1 + 4*kt where k1 = k_sf % 4, kt = k_sf / 4 + // (k0 is the stride-0 inner SF vector index — always 0 for the first element) + int m0 = mn % 32; + int m1 = (mn / 32) % 4; + int mt = mn / 128; + int k0 = 0; // stride-0, within SF vector + int k1 = k_sf % 4; + int kt = k_sf / 4; + + int dst_idx = 0; if constexpr (LayoutRank == 2) { - dst_idx = layout_sf(cute::make_coord(mn, k_sf * InputSFVectorSize)); + auto hier_coord = cute::make_coord( + cute::make_coord(cute::make_coord(m0, m1), mt), + cute::make_coord(cute::make_coord(k0, k1), kt) + ); + dst_idx = layout_sf(hier_coord); } else if constexpr (LayoutRank == 3) { - dst_idx = layout_sf(cute::make_coord(mn, k_sf * InputSFVectorSize, 0)); + auto hier_coord = cute::make_coord( + cute::make_coord(cute::make_coord(m0, m1), mt), + cute::make_coord(cute::make_coord(k0, k1), kt), + 0 + ); + dst_idx = layout_sf(hier_coord); } int src_idx = mn * src_stride_mn + k_sf * src_stride_ksf;