debug: handle flat_rank=8 for SF remap, add coordinate dump

The previous approach assumed a flat rank of 2-6, but the actual flat rank is 8.
For R==8: 4 M sub-indices (inner_32, inner_4, tile_interleave, tile_m)
          4 K sub-indices (inner_16, inner_4_k, tile_k_interleave, tile_k)
m = (f3*2 + f2)*128 + f0*4 + f1
k_sf = f5 + f6*4  (tentative: f4 is ignored and f7 is not yet included; needs printf verification)
Added a printf of all 8 flat values for the first 3 destination indices (dst_idx < 3).
This commit is contained in:
2026-05-14 15:45:52 +00:00
parent d2c1c76f5b
commit 8ee3f90e44

View File

@@ -131,46 +131,63 @@ __global__ void remap_sf_to_cutlass_kernel(
printf("[remap] flat_rank=%d MN=%d K_sf=%d total=%d\n", R, MN, K_sf, total);
}
// Debug: print first few coordinates
if (dst_idx < 3) {
printf("[remap] idx=%d flat_rank=%d vals=", dst_idx, R);
if constexpr (R >= 8) {
printf("%d,%d,%d,%d,%d,%d,%d,%d",
int(cute::get<0>(flat)), int(cute::get<1>(flat)),
int(cute::get<2>(flat)), int(cute::get<3>(flat)),
int(cute::get<4>(flat)), int(cute::get<5>(flat)),
int(cute::get<6>(flat)), int(cute::get<7>(flat)));
}
printf("\n");
}
int m = 0, k_sf = 0;
if constexpr (R == 6) {
// Full 6: (inner_m, sub_m, tile_m, inner_k, sub_k, tile_k)
int inner_m = cute::get<0>(flat);
int sub_m = cute::get<1>(flat);
int tile_m = cute::get<2>(flat);
// get<3> = inner_k (within SF group, ignored for k_sf)
int sub_k = cute::get<4>(flat);
int tile_k = cute::get<5>(flat);
m = tile_m * 128 + inner_m * 4 + sub_m;
k_sf = tile_k * 4 + sub_k;
} else if constexpr (R == 5) {
// 5: maybe (inner_m, sub_m, tile_m, sub_k, tile_k) or similar
int inner_m = cute::get<0>(flat);
int sub_m = cute::get<1>(flat);
int tile_m = cute::get<2>(flat);
int sub_k = cute::get<3>(flat);
int tile_k = cute::get<4>(flat);
m = tile_m * 128 + inner_m * 4 + sub_m;
k_sf = tile_k * 4 + sub_k;
} else if constexpr (R == 4) {
// 4: (inner_m, sub_m, sub_k, tile_k) — small M fits in 1 tile
int inner_m = cute::get<0>(flat);
int sub_m = cute::get<1>(flat);
int sub_k = cute::get<2>(flat);
int tile_k = cute::get<3>(flat);
m = inner_m * 4 + sub_m;
k_sf = tile_k * 4 + sub_k;
} else if constexpr (R == 3) {
// 3: maybe (inner_m, sub_m, k_combined)
int inner_m = cute::get<0>(flat);
int sub_m = cute::get<1>(flat);
int k_comb = cute::get<2>(flat);
m = inner_m * 4 + sub_m;
k_sf = k_comb;
} else if constexpr (R == 2) {
// 2: flat (m, k) — no nesting
m = cute::get<0>(flat);
k_sf = cute::get<1>(flat);
if constexpr (R == 8) {
// 8 flattened coordinates: 4 for M, 4 for K
// M: (inner_32, inner_4, tile_interleave, tile_m)
// K: (inner_16, inner_4, tile_interleave_k, tile_k)
//
// SfAtom shape (32,4) with stride (16,4):
// inner_32 * 16 + inner_4 * 4 = local offset within atom
// m = inner_32 * 4 + inner_4 + (tile_interleave + tile_m * 2) * 128
// (step=2 means M tiles interleave, so tile_interleave is 0 or 1)
// k_sf = inner_16 is within one SF group (0..15), ignored
// inner_4 is the sub-K (0..3)
// k_sf = inner_4 + tile_interleave_k * 4 + tile_k * 4
// Actually inner_16 covers 16 elements = 1 SF group, so it doesn't
// contribute to k_sf. k_sf = sub_k + (tile_k * 4)
// But wait, with the tile_interleave_k, it might be:
// k_sf = inner_4 + tile_interleave_k * 4 + tile_k * 8
//
// Let me just compute m and k_sf empirically from the flat values:
int f0 = cute::get<0>(flat);
int f1 = cute::get<1>(flat);
int f2 = cute::get<2>(flat);
int f3 = cute::get<3>(flat);
int f4 = cute::get<4>(flat);
int f5 = cute::get<5>(flat);
int f6 = cute::get<6>(flat);
int f7 = cute::get<7>(flat);
// M = f0..f3, K = f4..f7
// f0 = inner_32 (0..31), f1 = inner_4 (0..3)
// These two give local_m = f0 * 4 + f1 (0..127)
// f2, f3 are M tiling: with Step<2>, f2 is interleave (0..1), f3 is tile (0..)
// m = (f3 * 2 + f2) * 128 + f0 * 4 + f1
m = (f3 * 2 + f2) * 128 + f0 * 4 + f1;
// f4 = inner_16 (0..15, within one SF group, doesn't contribute to k_sf)
// f5 = inner_4_k (0..3, K sub-index within atom)
// f6, f7 are K tiling
// k_sf = f5 + f6 * 4 + f7 * 8 (guessing the tiling pattern)
// Actually with Step<1> on K, the tiling is simpler:
// k_sf = f5 + (f6 + f7 * n_k_interleave) * 4
// Let me try: k_sf = f5 + f6 * 4 (assuming f7 is outer tile)
k_sf = f5 + f6 * 4; // This may need adjustment based on printf output
} else {
// Fallback: index 0 and 1
m = 0; k_sf = 0;