fix: remove duplicate TMEM_COLS_NEEDED declarations

This commit is contained in:
2026-05-28 07:43:54 +00:00
parent 278f1b34af
commit 579dd061cd

View File

@@ -111,7 +111,7 @@ fmha_decode_tmem(
// Initialize TMEM O to zero — warp-collective
// Use TMEM_COLS_NEEDED columns, each zeroed by all 32 lanes writing 0.
constexpr int TMEM_COLS_NEEDED = (HD + 127) / 128;
// TMEM_COLS_NEEDED defined above
if (wid == 0) {
for (int col = 0; col < TMEM_COLS_NEEDED; col++) {
tmem_store(tmem_base + col, 0, 0, 0, 0);
@@ -233,7 +233,7 @@ fmha_decode_tmem(
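// Lane i writes sPvBuf[base + i*4+0..3] to column col, where base = col * 128
// (for HD <= 128 this is just sPvBuf[i*4+0..3] to column 0).
// Lanes with base + i*4 >= HD write zeros (don't-care, but must participate).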
//
constexpr int TMEM_COLS_NEEDED = (HD + 127) / 128; // 1 for HD<=128, 2 for HD<=256
if (wid == 0) {
for (int col = 0; col < TMEM_COLS_NEEDED; col++) {
int base = col * 128; // FP32 offset for this column
@@ -256,7 +256,7 @@ fmha_decode_tmem(
// Step 2: Read from TMEM to registers (warp 0, warp-collective)
// Same lane mapping: lane i reads positions i*4+0..3 from the column.
constexpr int TMEM_COLS_NEEDED = (HD + 127) / 128;
if (wid == 0) {
float inv_sum = 1.0f / sRowSums[0];