Fix B2 indexer: increase TMEM_COLS to 512 for full 128-row MMA output

The MMA produces 128 rows × 128 cols = 4 row-groups × 128 TMEM cols = 512 total.
Even though we only read rows 0-63, the MMA writes all 128 rows.
TMEM_COLS must match the MMA output size, not just the read size.
This commit is contained in:
2026-06-03 00:45:15 +00:00
parent 797345dfe9
commit d36dbba01c

View File

@@ -149,8 +149,7 @@ indexer_fp8_score_topk_kernel(
constexpr int MMA_K_F8 = 32;
constexpr int NKT = 4; // ihd=128 / MMA_K_F8=32
constexpr int TILE_F8 = 128 * 32; // 4096 bytes per SMEM tile
constexpr int TMEM_COLS = 256; // 128 rows × 128 cols needs 4×128 = 512,
// but only 64 rows used (2×128 = 256)
constexpr int TMEM_COLS = 512; // 128 rows × 128 cols → 4 row-groups × 128 cols = 512
const int tid = threadIdx.x;
const int wid = tid >> 5;