diff --git a/dsv4/kernels/attention/fmha_epilogue_sm100.cuh b/dsv4/kernels/attention/fmha_epilogue_sm100.cuh index 6a2e7db1..73d6b28d 100644 --- a/dsv4/kernels/attention/fmha_epilogue_sm100.cuh +++ b/dsv4/kernels/attention/fmha_epilogue_sm100.cuh @@ -219,11 +219,12 @@ fmha_decode_tmem( // ================================================================ // Step 1: Write SMEM accumulator to TMEM (warp 0, warp-collective) - // Each TMEM column holds 4 FP32 values. - // sPvBuf[d] for d=0..HD-1 maps to column d/4, register d%4. + // IMPORTANT: ALL 32 lanes must call tmem_store (warp-collective). + // With TMEM_O_COLS=16 and WARP=32, only lanes 0-15 would enter the loop. + // Fix: loop over TMEM_N (always >= 32) so all lanes participate. + // Lanes writing beyond TMEM_O_COLS write don't-care data to don't-care columns. if (wid == 0) { - // Distribute columns across the warp's 32 lanes - for (int col = lane; col < TMEM_O_COLS; col += WARP) { + for (int col = lane; col < TMEM_N; col += WARP) { int d0 = col * 4 + 0; int d1 = col * 4 + 1; int d2 = col * 4 + 2; @@ -242,10 +243,11 @@ fmha_decode_tmem( if (tid == 0) printf("[tmem] wrote to TMEM OK\n"); // Step 2: Read from TMEM to registers (warp 0, warp-collective) + // Same warp-collective constraint: ALL 32 lanes must call tmem_load. if (wid == 0) { float inv_sum = 1.0f / sRowSums[0]; - for (int col = lane; col < TMEM_O_COLS; col += WARP) { + for (int col = lane; col < TMEM_N; col += WARP) { uint32_t u0, u1, u2, u3; tmem_load(tmem_base + col, u0, u1, u2, u3);