From 579dd061cd8ae01fa5b2b76952dc25a93c744326 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Thu, 28 May 2026 07:43:54 +0000 Subject: [PATCH] fix: remove duplicate TMEM_COLS_NEEDED declarations --- dsv4/kernels/attention/fmha_epilogue_sm100.cuh | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/dsv4/kernels/attention/fmha_epilogue_sm100.cuh b/dsv4/kernels/attention/fmha_epilogue_sm100.cuh index 181eacbc..0c04d696 100644 --- a/dsv4/kernels/attention/fmha_epilogue_sm100.cuh +++ b/dsv4/kernels/attention/fmha_epilogue_sm100.cuh @@ -111,7 +111,7 @@ fmha_decode_tmem( // Initialize TMEM O to zero — warp-collective // Use TMEM_COLS_NEEDED columns, each zeroed by all 32 lanes writing 0. - constexpr int TMEM_COLS_NEEDED = (HD + 127) / 128; + // TMEM_COLS_NEEDED defined above if (wid == 0) { for (int col = 0; col < TMEM_COLS_NEEDED; col++) { tmem_store(tmem_base + col, 0, 0, 0, 0); @@ -233,7 +233,7 @@ fmha_decode_tmem( // Lane i writes sPvBuf[i*4+0..3] to column 0. // Lanes with i*4 >= HD write zeros (don't-care, but must participate). // - constexpr int TMEM_COLS_NEEDED = (HD + 127) / 128; // 1 for HD<=128, 2 for HD<=256 + if (wid == 0) { for (int col = 0; col < TMEM_COLS_NEEDED; col++) { int base = col * 128; // FP32 offset for this column @@ -256,7 +256,7 @@ fmha_decode_tmem( // Step 2: Read from TMEM to registers (warp 0, warp-collective) // Same lane mapping: lane i reads positions i*4+0..3 from the column. - constexpr int TMEM_COLS_NEEDED = (HD + 127) / 128; + if (wid == 0) { float inv_sum = 1.0f / sRowSums[0];