fix: use crd2idx instead of layout operator() for SF forward mapping
This commit is contained in:
@@ -137,10 +137,18 @@ __global__ void remap_sf_to_cutlass_kernel(
|
||||
constexpr int LayoutRank = cute::rank_v<decltype(layout_sf.shape())>;
|
||||
|
||||
int dst_idx = 0;
|
||||
// Use crd2idx to map the flat logical coordinate to the CUTLASS physical index.
|
||||
// The layout's logical shape is ((32,4,...), (SFVecSize,4,...)) for K-major,
|
||||
// but CuTe accepts flat coordinates and internally maps them.
|
||||
// For SFA: logical coord is (m, k_element, l)
|
||||
// For SFB: logical coord is (n, k_element, l)
|
||||
// k_element = k_sf * InputSFVectorSize (converts the SF group index to a K-element index)
|
||||
if constexpr (LayoutRank == 2) {
|
||||
dst_idx = layout_sf(cute::make_coord(mn, k_sf * InputSFVectorSize));
|
||||
auto logical_coord = cute::make_coord(mn, k_sf * InputSFVectorSize);
|
||||
dst_idx = cute::crd2idx(logical_coord, layout_sf);
|
||||
} else if constexpr (LayoutRank == 3) {
|
||||
dst_idx = layout_sf(cute::make_coord(mn, k_sf * InputSFVectorSize, 0));
|
||||
auto logical_coord = cute::make_coord(mn, k_sf * InputSFVectorSize, 0);
|
||||
dst_idx = cute::crd2idx(logical_coord, layout_sf);
|
||||
}
|
||||
|
||||
int src_idx = mn * src_stride_mn + k_sf * src_stride_ksf;
|
||||
|
||||
Reference in New Issue
Block a user