refactor(tmem): clean rewrite of TMEM epilogue kernel

Removed all dead code from the first (broken) attention loop approach.
Clean pipeline: SMEM attention → TMEM write → TMEM read → normalize → GMEM.

Also renamed sPvBuf to sO for clarity (same as reference kernel).
This commit is contained in:
2026-05-28 07:49:03 +00:00
parent c989dc78d9
commit 446a0ca9fd

View File

@@ -1,52 +1,20 @@
/**
* DSV4 FMHA Phase 2 — TMEM accumulator + one-way correction epilogue.
*
* ==================================================================
* STATUS: FIXING — TMEM ops must be warp-collective
* ==================================================================
* STATUS: WORKING — TMEM pipeline functional (SMEM → TMEM → regs → normalize → GMEM)
*
* The root cause of the hang was identified:
* This kernel proves the MoE-style one-way correction epilogue works for FMHA:
* 1. Compute attention in SMEM (same as reference)
* 2. Write accumulator to TMEM (warp-collective store)
* 3. Read from TMEM to registers (warp-collective load)
* 4. Normalize in registers (per-lane math)
* 5. Cast to BF16 and write to GMEM
*
* 1. tcgen05.ld and tcgen05.st are WARP-COLLECTIVE operations. ALL 32 lanes
* in a warp must execute them. The old code guarded TMEM ops with
* `if (tid == 0)`, causing only lane 0 to execute = warp divergence
* on a collective op = HANG.
*
* 2. tmem_dealloc was passing the SMEM pointer instead of tmem_base
* (the value WRITTEN to SMEM by tcgen05.alloc).
*
* 3. The TMEM fence was already fixed: tcgen05.wait::st.sync.aligned
* and tcgen05.wait::ld.sync.aligned (the old tcgen05.fence doesn't exist).
*
* ==================================================================
* DESIGN: Warp-collective TMEM with scalar computation
* ==================================================================
*
* Thread 0 computes the attention loop (QK, softmax, P@V) and writes
* intermediate values to SMEM buffers. Warp 0 (all 32 lanes) then
* performs TMEM load/modify/store collectively. This ensures:
* - Correctness: same math as the reference kernel
* - No warp divergence on collective ops
* - TMEM is used as the accumulator (the whole point of Phase 2)
*
* For a single-column case (hd<=4), we still allocate 32 TMEM columns
* (minimum for tcgen05.alloc) but only use the first ceil(HD/4).
*
* ==================================================================
* TMEM LAYOUT (for tcgen05.ld/st 16x256b.x1.b32)
* ==================================================================
*
* Each tcgen05.ld/st operates on one "column" of TMEM. A column holds
* 16 rows × 256 bits = 16 × 8 × 32-bit registers = 4 uint32_t per lane.
* But since this is warp-collective, the 4 uint32_t per lane across 32 lanes
* gives 128 uint32_t per column, covering 16 rows × 8 FP32 per row.
*
* For T=1 decode, we only care about row 0. Lane 0's 4 registers map to
* 4 FP32 values in row 0. So for HD head_dim values, we need
* ceil(HD/4) columns, accessed at column indices 0, 1, 2, ...
*
* Column address = tmem_base + column_index.
* tmem_base is the value written to SMEM by tcgen05.alloc (typically 0).
* TMEM lane mapping (verified on B200 via test_tmem_lane_mapping.cu):
* tcgen05.st/ld 16x256b.x1.b32 is warp-collective. Each lane i
* writes/reads positions i*4+0..i*4+3 within the column.
* 32 lanes × 4 FP32 = 128 FP32 per column.
* For row 0: lane 0 = positions 0-3, lane 1 = 4-7, ..., lane 31 = 124-127.
*/
#pragma once
#include "fmha_common.cuh"
@@ -70,48 +38,38 @@ fmha_decode_tmem(
const bf16_t* vb = v + batch*bstride_kv;
bf16_t* oh = o + batch*bstride_o + head*HD;
// TMEM layout (verified on B200 via test_tmem_lane_mapping.cu):
// tcgen05.st 16x256b.x1.b32: each lane i writes 4 FP32 to positions
// i*4+0..i*4+3 within the column. 32 lanes × 4 = 128 FP32 per column.
// For row 0, lane 0 = positions 0-3, lane 1 = 4-7, ..., lane 31 = 124-127.
// So HD values need ceil(HD/128) columns.
constexpr int TMEM_COLS_NEEDED = (HD + 127) / 128; // 1 for HD<=128, 2 for HD<=256
// tcgen05.alloc requires power-of-2 columns, minimum 32
constexpr int TMEM_N = TMEM_COLS_NEEDED <= 32 ? 32 :
(TMEM_COLS_NEEDED <= 64 ? 64 :
(TMEM_COLS_NEEDED <= 128 ? 128 : 256));
// TMEM column count: 128 FP32 per column, need ceil(HD/128)
constexpr int TMEM_COLS_NEEDED = (HD + 127) / 128;
// tcgen05.alloc: power-of-2, minimum 32
constexpr int TMEM_N = TMEM_COLS_NEEDED <= 32 ? 32 : 64;
// SMEM layout:
// [0..3] tmem_base (written by tcgen05.alloc)
// [4..4+HD*4) sQ (HD floats)
// [4+HD*4..4+HD*4+4) sRowSums (1 float)
// [4+HD*4+8..) sPvBuf (4 floats for P@V intermediate)
// [0..3] tmem_base (written by tcgen05.alloc)
// [4..4+HD*4) sQ (HD floats)
// [4+HD*4..+8) sRowSums (1 float)
// [8+HD*4..] sO (HD floats — attention accumulator)
extern __shared__ char sbuf[];
uint32_t* sTmemBase = (uint32_t*)sbuf;
float* sQ = (float*)(sbuf + sizeof(uint32_t));
float* sRowSums = (float*)(sbuf + sizeof(uint32_t) + HD * sizeof(float));
float* sPvBuf = (float*)(sbuf + sizeof(uint32_t) + (HD + 1) * sizeof(float));
float* sO = (float*)(sbuf + sizeof(uint32_t) + (HD + 1) * sizeof(float));
// Load Q to SMEM
for (int d = tid; d < HD; d += NTHREADS) sQ[d] = bf16_to_f32(qh[d]);
// Load Q to SMEM + init accumulator
for (int d = tid; d < HD; d += NTHREADS) {
sQ[d] = bf16_to_f32(qh[d]);
sO[d] = 0.0f;
}
__syncthreads();
// ================================================================
// TMEM Allocation — warp-collective (all 32 lanes of warp 0)
// ================================================================
// TMEM alloc — warp-collective (all 32 lanes of warp 0)
if (wid == 0) {
uint32_t smem_ptr = __cvta_generic_to_shared(sTmemBase);
tmem_alloc(smem_ptr, TMEM_N);
}
__syncthreads();
// Read tmem_base from SMEM (written by alloc)
uint32_t tmem_base = *sTmemBase;
if (tid == 0) printf("[tmem] base=%u, alloc_n=%d, cols_needed=%d\n", tmem_base, TMEM_N, TMEM_COLS_NEEDED);
// Initialize TMEM O to zero — warp-collective
// Use TMEM_COLS_NEEDED columns, each zeroed by all 32 lanes writing 0.
// TMEM_COLS_NEEDED defined above
// Zero TMEM columns — warp-collective
if (wid == 0) {
for (int col = 0; col < TMEM_COLS_NEEDED; col++) {
tmem_store(tmem_base + col, 0, 0, 0, 0);
@@ -119,73 +77,12 @@ fmha_decode_tmem(
tmem_fence_store();
}
__syncthreads();
if (tid == 0) printf("[tmem] zeroed OK\n");
// ================================================================
// Attention computation — thread 0 does the math, warp 0 does TMEM
// Attention computation in SMEM (same as reference kernel)
// ================================================================
float row_max = -INFINITY, row_sum = 0.0f;
// Thread 0 computes the full attention loop
if (tid == 0) {
for (int c = 0; c < s_k; c++) {
// QK dot product
float s_val = 0.0f;
for (int d = 0; d < HD; d++) s_val += sQ[d] * bf16_to_f32(kb[c * HD + d]);
s_val *= scale;
// D3: SWA mask
if (swa_len > 0 && c >= n_comp + swa_len) s_val = -INFINITY;
float new_max = fmaxf(row_max, s_val);
if (new_max > row_max) {
float rescale = expf(row_max - new_max);
// D1.5: Rescale O in TMEM
// Write rescale factor to SMEM so warp 0 can use it collectively
sPvBuf[0] = rescale;
sPvBuf[1] = 1.0f; // signal: rescale needed
// Wake warp 0 to do the rescale
__threadfence_block(); // ensure SMEM writes visible
}
float p_val = expf(s_val - row_max);
row_sum += p_val;
// Compute p_val * V[:,c] and write to SMEM for warp-collective TMEM store
for (int d = 0; d < HD; d++) {
sPvBuf[d] = p_val * bf16_to_f32(vb[d * s_k + c]);
}
__threadfence_block();
}
sRowSums[0] = row_sum;
}
// The above approach still has the warp-collective problem: thread 0
// writes to SMEM but warp 0 needs to read it and do TMEM ops.
// This requires a synchronization scheme that's complex.
//
// SIMPLER APPROACH: Do the attention loop in SMEM (like the reference
// kernel), then do a ONE-WAY TMEM write at the end. The TMEM is used
// for the correction epilogue, not for mid-loop accumulation.
// This proves the TMEM → regs → normalize → GMEM pipeline works.
//
// Once this works, we can move the accumulation into TMEM properly
// with a warp-cooperative design.
// Actually, let me take the simplest correct approach that proves
// the TMEM pipeline: compute in SMEM (reference style), then
// write result to TMEM, read back from TMEM, normalize, output.
// This tests the one-way correction epilogue path.
// Reset — compute attention in SMEM (same as reference)
for (int d = tid; d < HD; d += NTHREADS) sPvBuf[d] = 0.0f;
__syncthreads();
row_max = -INFINITY;
row_sum = 0.0f;
if (tid == 0) {
for (int c = 0; c < s_k; c++) {
float s_val = 0.0f;
@@ -196,67 +93,45 @@ fmha_decode_tmem(
float new_max = fmaxf(row_max, s_val);
if (new_max > row_max) {
float rescale = expf(row_max - new_max);
for (int d = 0; d < HD; d++) sPvBuf[d] *= rescale;
for (int d = 0; d < HD; d++) sO[d] *= rescale;
row_sum *= rescale;
row_max = new_max;
}
float p_val = expf(s_val - row_max);
row_sum += p_val;
for (int d = 0; d < HD; d++) sPvBuf[d] += p_val * bf16_to_f32(vb[d * s_k + c]);
for (int d = 0; d < HD; d++) sO[d] += p_val * bf16_to_f32(vb[d * s_k + c]);
}
sRowSums[0] = row_sum;
}
__syncthreads();
if (tid == 0) printf("[tmem] attention computed, row_sum=%f, sPvBuf[0]=%f, sPvBuf[32]=%f\n", sRowSums[0], sPvBuf[0], HD>32?sPvBuf[32]:0.0f);
// ================================================================
// One-way Correction Epilogue: SMEM → TMEM → regs → normalize → GMEM
//
// This is the production pipeline that the MoE kernel uses:
// 1. Write accumulator to TMEM (warp-collective store)
// 2. Read from TMEM to registers (warp-collective load)
// 3. Normalize in registers (per-lane math)
// 4. Cast to BF16 and write to GMEM
//
// Steps 1-2 prove the TMEM round-trip works (one-way, not
// the broken Ld32x32bOp/St32x32bOp from CuTeDSL).
// ================================================================
// Step 1: Write SMEM accumulator to TMEM (warp 0, warp-collective)
//
// TMEM lane mapping (verified on B200 via test_tmem_lane_mapping.cu):
// tcgen05.st 16x256b.x1.b32: each lane i writes 4 FP32 to positions
// i*4+0..i*4+3 within the column. 32 lanes × 4 = 128 FP32 per column.
// For row 0, lane 0 = positions 0-3, lane 1 = 4-7, ..., lane 31 = 124-127.
//
// So for HD<=128, a single column suffices (128 FP32 per col, row 0).
// Lane i writes sPvBuf[i*4+0..3] to column 0.
// Lanes with i*4 >= HD write zeros (don't-care, but must participate).
//
// Lane i writes sO[i*4+0..3] to column (i*4) / 128.
// All 32 lanes must call tmem_store (warp-collective).
if (wid == 0) {
for (int col = 0; col < TMEM_COLS_NEEDED; col++) {
int base = col * 128; // FP32 offset for this column
int base = col * 128;
int d0 = base + lane * 4 + 0;
int d1 = base + lane * 4 + 1;
int d2 = base + lane * 4 + 2;
int d3 = base + lane * 4 + 3;
uint32_t u0 = (d0 < HD) ? f32_to_u32(sPvBuf[d0]) : 0;
uint32_t u1 = (d1 < HD) ? f32_to_u32(sPvBuf[d1]) : 0;
uint32_t u2 = (d2 < HD) ? f32_to_u32(sPvBuf[d2]) : 0;
uint32_t u3 = (d3 < HD) ? f32_to_u32(sPvBuf[d3]) : 0;
uint32_t u0 = (d0 < HD) ? f32_to_u32(sO[d0]) : 0;
uint32_t u1 = (d1 < HD) ? f32_to_u32(sO[d1]) : 0;
uint32_t u2 = (d2 < HD) ? f32_to_u32(sO[d2]) : 0;
uint32_t u3 = (d3 < HD) ? f32_to_u32(sO[d3]) : 0;
tmem_store(tmem_base + col, u0, u1, u2, u3);
}
tmem_fence_store();
}
__syncthreads();
if (tid == 0) printf("[tmem] wrote to TMEM OK\n");
// Step 2: Read from TMEM to registers (warp 0, warp-collective)
// Same lane mapping: lane i reads positions i*4+0..3 from the column.
if (wid == 0) {
float inv_sum = 1.0f / sRowSums[0];
@@ -270,9 +145,6 @@ fmha_decode_tmem(
float r2 = u32_to_f32(u2) * inv_sum;
float r3 = u32_to_f32(u3) * inv_sum;
// Debug: print first few values from lane 0
if (lane == 0 && col == 0) printf("[tmem] read: r0=%f r1=%f r2=%f r3=%f inv_sum=%f\n", r0, r1, r2, r3, inv_sum);
// Step 4: Cast to BF16 and write to GMEM
int base = col * 128;
int d0 = base + lane * 4 + 0;
@@ -293,10 +165,7 @@ fmha_decode_tmem(
lse_out[batch * gridDim.y + head] = logf(row_sum) + row_max;
}
// ================================================================
// TMEM Deallocation — warp-collective (all 32 lanes of warp 0)
// Pass tmem_base (value written to SMEM by alloc), NOT the SMEM pointer
// ================================================================
// TMEM dealloc — warp-collective (pass tmem_base, not SMEM pointer)
if (wid == 0) {
tmem_dealloc(tmem_base, TMEM_N);
}