fix: use crd2idx instead of layout operator() for SF forward mapping

This commit is contained in:
2026-05-15 21:52:02 +00:00
parent a09d8e477e
commit 59dad8e2fb

View File

@@ -137,10 +137,18 @@ __global__ void remap_sf_to_cutlass_kernel(
constexpr int LayoutRank = cute::rank_v<decltype(layout_sf.shape())>;
int dst_idx = 0;
// Use crd2idx to map the flat logical coordinate to the CUTLASS physical index.
// For K-major, the layout's logical shape is ((32,4,...), (SFVecSize,4,...)),
// but CuTe accepts flat coordinates and maps them internally.
// For SFA, the logical coordinate is (m, k_element, l).
// For SFB, the logical coordinate is (n, k_element, l).
// k_element = k_sf * InputSFVectorSize (converts an SF group index to a K element index).
if constexpr (LayoutRank == 2) {
dst_idx = layout_sf(cute::make_coord(mn, k_sf * InputSFVectorSize));
auto logical_coord = cute::make_coord(mn, k_sf * InputSFVectorSize);
dst_idx = cute::crd2idx(logical_coord, layout_sf);
} else if constexpr (LayoutRank == 3) {
dst_idx = layout_sf(cute::make_coord(mn, k_sf * InputSFVectorSize, 0));
auto logical_coord = cute::make_coord(mn, k_sf * InputSFVectorSize, 0);
dst_idx = cute::crd2idx(logical_coord, layout_sf);
}
int src_idx = mn * src_stride_mn + k_sf * src_stride_ksf;