fix: divide K element index by SFVecSize to get k_sf

Based on veitner's bearblog analysis of the CUTLASS SF layout:
- Shape is ((32,4,K_tiles), (SFVecSize,4,M_tiles)) for SFA
- get<0..2> cover the K dimension, get<3..5> cover the M dimension
- k_sf = K_element_index / SFVecSize
This commit is contained in:
2026-05-15 20:17:24 +00:00
parent a09b9b53a3
commit ff5a0843dc

View File

@@ -136,13 +136,18 @@ __global__ void remap_sf_to_cutlass_kernel(
int m = 0, k_sf = 0;
if constexpr (R == 6) {
// K-major SfAtom: ((32, 4, k_tiles), (SFVecSize, 4, m_tiles))
// Flattened: (inner_k, sub_k, tile_k, inner_m, sub_m, tile_m)
k_sf = cute::get<0>(flat) + cute::get<1>(flat) * 32 + cute::get<2>(flat) * 128;
// K-major SfAtom tiled with Step<(2,1,3)>:
// Shape: ((32, 4, K_tiles), (SFVecSize, 4, M_tiles))
// First group covers K dimension, second covers M dimension
// get<0..2> = K group, get<3..5> = M group
// K element index = get<0> + get<1>*32 + get<2>*128
// k_sf = K_element_index / SFVecSize
// M element index = get<3> + get<4>*SFVecSize + get<5>*(SFVecSize*4)
k_sf = (cute::get<0>(flat) + cute::get<1>(flat) * 32 + cute::get<2>(flat) * 128) / InputSFVectorSize;
m = cute::get<3>(flat) + cute::get<4>(flat) * InputSFVectorSize + cute::get<5>(flat) * (InputSFVectorSize * 4);
} else if constexpr (R == 8) {
// With batch dimension: ((32, 4, k_tiles), (SFVecSize, 4, m_tiles), (1,))
k_sf = cute::get<0>(flat) + cute::get<1>(flat) * 32 + cute::get<2>(flat) * 128;
// With batch/L dimension: ((32, 4, K_tiles), (SFVecSize, 4, M_tiles), (1, 1, L))
k_sf = (cute::get<0>(flat) + cute::get<1>(flat) * 32 + cute::get<2>(flat) * 128) / InputSFVectorSize;
m = cute::get<3>(flat) + cute::get<4>(flat) * InputSFVectorSize + cute::get<5>(flat) * (InputSFVectorSize * 4);
} else {
m = 0; k_sf = 0;