fix: use hierarchical coordinates for layout_sf forward mapping

A flat make_coord(mn, k*16) doesn't decompose into the nested atom shape,
so the coordinate must be decomposed manually:
  mn -> (m0, m1, mt) where m0 = mn % 32, m1 = (mn / 32) % 4, mt = mn / 128
  k_sf -> (k0, k1, kt) where k0 = 0 (stride-0 mode), k1 = k_sf % 4, kt = k_sf / 4
This commit is contained in:
2026-05-15 22:11:14 +00:00
parent 3b4a7b591f
commit 5ff1b9e401

View File

@@ -137,15 +137,32 @@ __global__ void remap_sf_to_cutlass_kernel(
int dst_idx = 0;
// Use crd2idx to map flat logical coordinate to CUTLASS physical index.
// The layout's logical shape is ((32,4,...), (SFVecSize,4,...)) for K-major,
// but CuTe accepts flat coordinates and internally maps them.
// For SFA: logical coord is (m, k_element, l)
// For SFB: logical coord is (n, k_element, l)
// k_element = k_sf * InputSFVectorSize (convert SF group to K element)
// Decompose flat (mn, k_sf) into hierarchical coordinates matching the atom layout:
// Shape: (((32, 4), mn_tiles), ((SFVecSize, 4), k_tiles), ...)
// First group: mn = m0 + 32*m1 + 128*mt where m0 = mn % 32 in [0,32), m1 = (mn / 32) % 4 in [0,4), mt = mn / 128
// Second group: k_sf = k1 + 4*kt where k1 = k_sf % 4, kt = k_sf / 4
// (k0 is the stride-0 index within the SF vector; it is fixed at 0 here)
int m0 = mn % 32;
int m1 = (mn / 32) % 4;
int mt = mn / 128;
int k0 = 0; // stride-0, within SF vector
int k1 = k_sf % 4;
int kt = k_sf / 4;
int dst_idx = 0;
if constexpr (LayoutRank == 2) {
dst_idx = layout_sf(cute::make_coord(mn, k_sf * InputSFVectorSize));
auto hier_coord = cute::make_coord(
cute::make_coord(cute::make_coord(m0, m1), mt),
cute::make_coord(cute::make_coord(k0, k1), kt)
);
dst_idx = layout_sf(hier_coord);
} else if constexpr (LayoutRank == 3) {
dst_idx = layout_sf(cute::make_coord(mn, k_sf * InputSFVectorSize, 0));
auto hier_coord = cute::make_coord(
cute::make_coord(cute::make_coord(m0, m1), mt),
cute::make_coord(cute::make_coord(k0, k1), kt),
0
);
dst_idx = layout_sf(hier_coord);
}
int src_idx = mn * src_stride_mn + k_sf * src_stride_ksf;