From 33cedbee0a2412e444a3e3e6fb425b4e4a3e5012 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Thu, 28 May 2026 07:41:16 +0000 Subject: [PATCH] =?UTF-8?q?fix(tmem):=20TMEM=20ld/st=20are=20warp-collecti?= =?UTF-8?q?ve=20=E2=80=94=20ALL=2032=20lanes=20must=20call=20them?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Root cause of TMEM epilogue hang: tmem_store/tmem_load are warp-collective operations requiring ALL 32 lanes to participate. The loop 'for (col = lane; col < TMEM_O_COLS; col += WARP)' with TMEM_O_COLS=16 and WARP=32 means only lanes 0-15 execute the op. Lanes 16-31 skip it, so the warp diverges on a collective operation and the kernel hangs. Fix: loop over TMEM_N (>= 32, power of 2) so all 32 lanes participate. Lanes whose column index is at or beyond TMEM_O_COLS write don't-care data to allocated-but-unused TMEM columns. --- dsv4/kernels/attention/fmha_epilogue_sm100.cuh | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/dsv4/kernels/attention/fmha_epilogue_sm100.cuh b/dsv4/kernels/attention/fmha_epilogue_sm100.cuh index 6a2e7db1..73d6b28d 100644 --- a/dsv4/kernels/attention/fmha_epilogue_sm100.cuh +++ b/dsv4/kernels/attention/fmha_epilogue_sm100.cuh @@ -219,11 +219,12 @@ fmha_decode_tmem( // ================================================================ // Step 1: Write SMEM accumulator to TMEM (warp 0, warp-collective) - // Each TMEM column holds 4 FP32 values. - // sPvBuf[d] for d=0..HD-1 maps to column d/4, register d%4. + // IMPORTANT: ALL 32 lanes must call tmem_store (warp-collective). + // With TMEM_O_COLS=16 and WARP=32, only lanes 0-15 would enter the loop. + // Fix: loop over TMEM_N (always >= 32) so all lanes participate. + // Lanes writing beyond TMEM_O_COLS write don't-care data to don't-care columns. 
if (wid == 0) { - // Distribute columns across the warp's 32 lanes - for (int col = lane; col < TMEM_O_COLS; col += WARP) { + for (int col = lane; col < TMEM_N; col += WARP) { int d0 = col * 4 + 0; int d1 = col * 4 + 1; int d2 = col * 4 + 2; @@ -242,10 +243,11 @@ fmha_decode_tmem( if (tid == 0) printf("[tmem] wrote to TMEM OK\n"); // Step 2: Read from TMEM to registers (warp 0, warp-collective) + // Same warp-collective constraint: ALL 32 lanes must call tmem_load. if (wid == 0) { float inv_sum = 1.0f / sRowSums[0]; - for (int col = lane; col < TMEM_O_COLS; col += WARP) { + for (int col = lane; col < TMEM_N; col += WARP) { uint32_t u0, u1, u2, u3; tmem_load(tmem_base + col, u0, u1, u2, u3);