debug: handle flat_rank=8 for SF remap, add coordinate dump
Previous approach assumed rank 2-6, but actual rank is 8.
For R==8: 4 M sub-indices (inner_32, inner_4, tile_interleave, tile_m)
4 K sub-indices (inner_16, inner_4_k, tile_k_interleave, tile_k)
m = (f3*2 + f2)*128 + f0*4 + f1
k_sf = f5 + f6*4 (tentative, needs printf verification)
Added printf of all 8 flat values for first 3 indices.
This commit is contained in:
@@ -131,46 +131,63 @@ __global__ void remap_sf_to_cutlass_kernel(
|
||||
printf("[remap] flat_rank=%d MN=%d K_sf=%d total=%d\n", R, MN, K_sf, total);
|
||||
}
|
||||
|
||||
// Debug: print first few coordinates
|
||||
if (dst_idx < 3) {
|
||||
printf("[remap] idx=%d flat_rank=%d vals=", dst_idx, R);
|
||||
if constexpr (R >= 8) {
|
||||
printf("%d,%d,%d,%d,%d,%d,%d,%d",
|
||||
int(cute::get<0>(flat)), int(cute::get<1>(flat)),
|
||||
int(cute::get<2>(flat)), int(cute::get<3>(flat)),
|
||||
int(cute::get<4>(flat)), int(cute::get<5>(flat)),
|
||||
int(cute::get<6>(flat)), int(cute::get<7>(flat)));
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
|
||||
int m = 0, k_sf = 0;
|
||||
|
||||
if constexpr (R == 6) {
|
||||
// Full 6: (inner_m, sub_m, tile_m, inner_k, sub_k, tile_k)
|
||||
int inner_m = cute::get<0>(flat);
|
||||
int sub_m = cute::get<1>(flat);
|
||||
int tile_m = cute::get<2>(flat);
|
||||
// get<3> = inner_k (within SF group, ignored for k_sf)
|
||||
int sub_k = cute::get<4>(flat);
|
||||
int tile_k = cute::get<5>(flat);
|
||||
m = tile_m * 128 + inner_m * 4 + sub_m;
|
||||
k_sf = tile_k * 4 + sub_k;
|
||||
} else if constexpr (R == 5) {
|
||||
// 5: maybe (inner_m, sub_m, tile_m, sub_k, tile_k) or similar
|
||||
int inner_m = cute::get<0>(flat);
|
||||
int sub_m = cute::get<1>(flat);
|
||||
int tile_m = cute::get<2>(flat);
|
||||
int sub_k = cute::get<3>(flat);
|
||||
int tile_k = cute::get<4>(flat);
|
||||
m = tile_m * 128 + inner_m * 4 + sub_m;
|
||||
k_sf = tile_k * 4 + sub_k;
|
||||
} else if constexpr (R == 4) {
|
||||
// 4: (inner_m, sub_m, sub_k, tile_k) — small M fits in 1 tile
|
||||
int inner_m = cute::get<0>(flat);
|
||||
int sub_m = cute::get<1>(flat);
|
||||
int sub_k = cute::get<2>(flat);
|
||||
int tile_k = cute::get<3>(flat);
|
||||
m = inner_m * 4 + sub_m;
|
||||
k_sf = tile_k * 4 + sub_k;
|
||||
} else if constexpr (R == 3) {
|
||||
// 3: maybe (inner_m, sub_m, k_combined)
|
||||
int inner_m = cute::get<0>(flat);
|
||||
int sub_m = cute::get<1>(flat);
|
||||
int k_comb = cute::get<2>(flat);
|
||||
m = inner_m * 4 + sub_m;
|
||||
k_sf = k_comb;
|
||||
} else if constexpr (R == 2) {
|
||||
// 2: flat (m, k) — no nesting
|
||||
m = cute::get<0>(flat);
|
||||
k_sf = cute::get<1>(flat);
|
||||
if constexpr (R == 8) {
|
||||
// 8 flattened coordinates: 4 for M, 4 for K
|
||||
// M: (inner_32, inner_4, tile_interleave, tile_m)
|
||||
// K: (inner_16, inner_4, tile_interleave_k, tile_k)
|
||||
//
|
||||
// SfAtom shape (32,4) with stride (16,4):
|
||||
// inner_32 * 16 + inner_4 * 4 = local offset within atom
|
||||
// m = inner_32 * 4 + inner_4 + (tile_interleave + tile_m * 2) * 128
|
||||
// (step=2 means M tiles interleave, so tile_interleave is 0 or 1)
|
||||
// k_sf = inner_16 is within one SF group (0..15), ignored
|
||||
// inner_4 is the sub-K (0..3)
|
||||
// k_sf = inner_4 + tile_interleave_k * 4 + tile_k * 4
|
||||
// Actually inner_16 covers 16 elements = 1 SF group, so it doesn't
|
||||
// contribute to k_sf. k_sf = sub_k + (tile_k * 4)
|
||||
// But wait, with the tile_interleave_k, it might be:
|
||||
// k_sf = inner_4 + tile_interleave_k * 4 + tile_k * 8
|
||||
//
|
||||
// Let me just compute m and k_sf empirically from the flat values:
|
||||
int f0 = cute::get<0>(flat);
|
||||
int f1 = cute::get<1>(flat);
|
||||
int f2 = cute::get<2>(flat);
|
||||
int f3 = cute::get<3>(flat);
|
||||
int f4 = cute::get<4>(flat);
|
||||
int f5 = cute::get<5>(flat);
|
||||
int f6 = cute::get<6>(flat);
|
||||
int f7 = cute::get<7>(flat);
|
||||
|
||||
// M = f0..f3, K = f4..f7
|
||||
// f0 = inner_32 (0..31), f1 = inner_4 (0..3)
|
||||
// These two give local_m = f0 * 4 + f1 (0..127)
|
||||
// f2, f3 are M tiling: with Step<2>, f2 is interleave (0..1), f3 is tile (0..)
|
||||
// m = (f3 * 2 + f2) * 128 + f0 * 4 + f1
|
||||
m = (f3 * 2 + f2) * 128 + f0 * 4 + f1;
|
||||
|
||||
// f4 = inner_16 (0..15, within one SF group, doesn't contribute to k_sf)
|
||||
// f5 = inner_4_k (0..3, K sub-index within atom)
|
||||
// f6, f7 are K tiling
|
||||
// k_sf = f5 + f6 * 4 + f7 * 8 (guessing the tiling pattern)
|
||||
// Actually with Step<1> on K, the tiling is simpler:
|
||||
// k_sf = f5 + (f6 + f7 * n_k_interleave) * 4
|
||||
// Let me try: k_sf = f5 + f6 * 4 (assuming f7 is outer tile)
|
||||
k_sf = f5 + f6 * 4; // This may need adjustment based on printf output
|
||||
} else {
|
||||
// Fallback: index 0 and 1
|
||||
m = 0; k_sf = 0;
|
||||
|
||||
Reference in New Issue
Block a user