fix(tmem): TMEM ld/st are warp-collective — ALL 32 lanes must call them

Root cause of TMEM epilogue hang: tmem_store/tmem_load are
warp-collective operations requiring ALL 32 lanes to participate.

The loop 'for (col = lane; col < TMEM_O_COLS; col += WARP)' with
TMEM_O_COLS=16 and WARP=32 means only lanes 0-15 execute the op.
Lanes 16-31 skip it = warp divergence on collective = HANG.

Fix: loop over TMEM_N (>= 32, power of 2) so all 32 lanes
participate. Columns beyond TMEM_O_COLS write don't-care data
to allocated-but-unused TMEM columns.
This commit is contained in:
2026-05-28 07:41:16 +00:00
parent cea02fe407
commit 33cedbee0a

View File

@@ -219,11 +219,12 @@ fmha_decode_tmem(
// ================================================================
// Step 1: Write SMEM accumulator to TMEM (warp 0, warp-collective)
// Each TMEM column holds 4 FP32 values.
// sPvBuf[d] for d=0..HD-1 maps to column d/4, register d%4.
// IMPORTANT: ALL 32 lanes must call tmem_store (warp-collective).
// With TMEM_O_COLS=16 and WARP=32, only lanes 0-15 would enter the loop.
// Fix: loop over TMEM_N (always >= 32) so all lanes participate.
// Lanes writing beyond TMEM_O_COLS write don't-care data to don't-care columns.
if (wid == 0) {
// Distribute columns across the warp's 32 lanes
for (int col = lane; col < TMEM_O_COLS; col += WARP) {
for (int col = lane; col < TMEM_N; col += WARP) {
int d0 = col * 4 + 0;
int d1 = col * 4 + 1;
int d2 = col * 4 + 2;
@@ -242,10 +243,11 @@ fmha_decode_tmem(
if (tid == 0) printf("[tmem] wrote to TMEM OK\n");
// Step 2: Read from TMEM to registers (warp 0, warp-collective)
// Same warp-collective constraint: ALL 32 lanes must call tmem_load.
if (wid == 0) {
float inv_sum = 1.0f / sRowSums[0];
for (int col = lane; col < TMEM_O_COLS; col += WARP) {
for (int col = lane; col < TMEM_N; col += WARP) {
uint32_t u0, u1, u2, u3;
tmem_load(tmem_base + col, u0, u1, u2, u3);