fix(tmem): TMEM ld/st are warp-collective — ALL 32 lanes must call them
Root cause of TMEM epilogue hang: tmem_store/tmem_load are warp-collective operations requiring ALL 32 lanes to participate. The loop 'for (col = lane; col < TMEM_O_COLS; col += WARP)' with TMEM_O_COLS=16 and WARP=32 means only lanes 0-15 execute the op. Lanes 16-31 skip it = warp divergence on collective = HANG. Fix: loop over TMEM_N (>= 32, power of 2) so all 32 lanes participate. Columns beyond TMEM_O_COLS write don't-care data to allocated-but-unused TMEM columns.
This commit is contained in:
@@ -219,11 +219,12 @@ fmha_decode_tmem(
|
||||
// ================================================================
|
||||
|
||||
// Step 1: Write SMEM accumulator to TMEM (warp 0, warp-collective)
|
||||
// Each TMEM column holds 4 FP32 values.
|
||||
// sPvBuf[d] for d=0..HD-1 maps to column d/4, register d%4.
|
||||
// IMPORTANT: ALL 32 lanes must call tmem_store (warp-collective).
|
||||
// With TMEM_O_COLS=16 and WARP=32, only lanes 0-15 would enter the loop.
|
||||
// Fix: loop over TMEM_N (always >= 32) so all lanes participate.
|
||||
// Lanes writing beyond TMEM_O_COLS write don't-care data to don't-care columns.
|
||||
if (wid == 0) {
|
||||
// Distribute columns across the warp's 32 lanes
|
||||
for (int col = lane; col < TMEM_O_COLS; col += WARP) {
|
||||
for (int col = lane; col < TMEM_N; col += WARP) {
|
||||
int d0 = col * 4 + 0;
|
||||
int d1 = col * 4 + 1;
|
||||
int d2 = col * 4 + 2;
|
||||
@@ -242,10 +243,11 @@ fmha_decode_tmem(
|
||||
if (tid == 0) printf("[tmem] wrote to TMEM OK\n");
|
||||
|
||||
// Step 2: Read from TMEM to registers (warp 0, warp-collective)
|
||||
// Same warp-collective constraint: ALL 32 lanes must call tmem_load.
|
||||
if (wid == 0) {
|
||||
float inv_sum = 1.0f / sRowSums[0];
|
||||
|
||||
for (int col = lane; col < TMEM_O_COLS; col += WARP) {
|
||||
for (int col = lane; col < TMEM_N; col += WARP) {
|
||||
uint32_t u0, u1, u2, u3;
|
||||
tmem_load(tmem_base + col, u0, u1, u2, u3);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user