Fix B2 indexer: increase TMEM_COLS to 512 for full 128-row MMA output
The MMA output is 128 rows × 128 cols, which occupies 4 row-groups × 128 TMEM cols = 512 cols in total. The kernel reads only rows 0-63, but the MMA still writes all 128 rows, so TMEM_COLS must match the MMA output size, not just the read size.
@@ -149,8 +149,7 @@ indexer_fp8_score_topk_kernel(
     constexpr int MMA_K_F8 = 32;
     constexpr int NKT = 4; // ihd=128 / MMA_K_F8=32
     constexpr int TILE_F8 = 128 * 32; // 4096 bytes per SMEM tile
-    constexpr int TMEM_COLS = 256; // 128 rows × 128 cols needs 4×128 = 512,
-    // but only 64 rows used (2×128 = 256)
+    constexpr int TMEM_COLS = 512; // 128 rows × 128 cols → 4 row-groups × 128 cols = 512
 
     const int tid = threadIdx.x;
     const int wid = tid >> 5;
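The sizing rule from the commit message could be derived at compile time instead of being hard-coded. The following is only a sketch: `ROWS_PER_GROUP`, `MMA_M`, `MMA_N`, `ROWS_READ`, and `tmem_cols_for` are hypothetical names that do not appear in the kernel, and the figure of 32 rows per row-group is inferred from the commit's own arithmetic (128 rows → 4 row-groups, 64 rows → 2), not taken from hardware documentation. The range and power-of-two check reflects our understanding of the TMEM allocation constraints and should be verified against the PTX documentation.

```cpp
// Sketch only: all names below are illustrative and not part of the kernel.
// Assumption: one row-group covers 32 MMA output rows, as implied by the
// commit's arithmetic (128 rows -> 4 row-groups, 64 rows -> 2 row-groups).
constexpr int ROWS_PER_GROUP = 32;

constexpr int MMA_M     = 128;  // rows the MMA writes
constexpr int MMA_N     = 128;  // output columns per row-group
constexpr int ROWS_READ = 64;   // rows the kernel actually reads (0-63)

// TMEM columns needed for an output of `rows` x `cols` under the
// row-group accounting used in the commit message.
constexpr int tmem_cols_for(int rows, int cols) {
    return (rows / ROWS_PER_GROUP) * cols;
}

// Size the allocation by what the MMA writes, not by what is read back.
constexpr int TMEM_COLS = tmem_cols_for(MMA_M, MMA_N);

static_assert(TMEM_COLS == 512,
              "full 128-row MMA output needs 4 row-groups x 128 cols");
static_assert(tmem_cols_for(ROWS_READ, MMA_N) == 256,
              "sizing by rows read reproduces the old, too-small value");

// Assumed allocation constraint: a power of two between 32 and 512 columns.
static_assert(TMEM_COLS >= 32 && TMEM_COLS <= 512 &&
              (TMEM_COLS & (TMEM_COLS - 1)) == 0,
              "TMEM_COLS must be a power of two in [32, 512]");
```

Deriving the constant this way keeps the old mistake visible: substituting `ROWS_READ` for `MMA_M` yields 256, the value this commit replaces.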