diff --git a/dsv4/kernels/attention/fmha_epilogue_sm100.cuh b/dsv4/kernels/attention/fmha_epilogue_sm100.cuh index 181eacbc..0c04d696 100644 --- a/dsv4/kernels/attention/fmha_epilogue_sm100.cuh +++ b/dsv4/kernels/attention/fmha_epilogue_sm100.cuh @@ -111,7 +111,7 @@ fmha_decode_tmem( // Initialize TMEM O to zero — warp-collective // Use TMEM_COLS_NEEDED columns, each zeroed by all 32 lanes writing 0. - constexpr int TMEM_COLS_NEEDED = (HD + 127) / 128; + // TMEM_COLS_NEEDED defined above if (wid == 0) { for (int col = 0; col < TMEM_COLS_NEEDED; col++) { tmem_store(tmem_base + col, 0, 0, 0, 0); @@ -233,7 +233,7 @@ fmha_decode_tmem( // Lane i writes sPvBuf[i*4+0..3] to column 0. // Lanes with i*4 >= HD write zeros (don't-care, but must participate). // - constexpr int TMEM_COLS_NEEDED = (HD + 127) / 128; // 1 for HD<=128, 2 for HD<=256 + if (wid == 0) { for (int col = 0; col < TMEM_COLS_NEEDED; col++) { int base = col * 128; // FP32 offset for this column @@ -256,7 +256,7 @@ fmha_decode_tmem( // Step 2: Read from TMEM to registers (warp 0, warp-collective) // Same lane mapping: lane i reads positions i*4+0..3 from the column. - constexpr int TMEM_COLS_NEEDED = (HD + 127) / 128; + if (wid == 0) { float inv_sum = 1.0f / sRowSums[0];