diff --git a/dsv4/kernels/attention/fmha_epilogue_sm100.cuh b/dsv4/kernels/attention/fmha_epilogue_sm100.cuh index fd44c954..7de3c15a 100644 --- a/dsv4/kernels/attention/fmha_epilogue_sm100.cuh +++ b/dsv4/kernels/attention/fmha_epilogue_sm100.cuh @@ -1,52 +1,20 @@ /** * DSV4 FMHA Phase 2 — TMEM accumulator + one-way correction epilogue. * - * ================================================================== - * STATUS: FIXING — TMEM ops must be warp-collective - * ================================================================== + * STATUS: WORKING — TMEM pipeline functional (SMEM → TMEM → regs → normalize → GMEM) * - * The root cause of the hang was identified: + * This kernel proves the MoE-style one-way correction epilogue works for FMHA: + * 1. Compute attention in SMEM (same as reference) + * 2. Write accumulator to TMEM (warp-collective store) + * 3. Read from TMEM to registers (warp-collective load) + * 4. Normalize in registers (per-lane math) + * 5. Cast to BF16 and write to GMEM * - * 1. tcgen05.ld and tcgen05.st are WARP-COLLECTIVE operations. ALL 32 lanes - * in a warp must execute them. The old code guarded TMEM ops with - * `if (tid == 0)`, causing only lane 0 to execute = warp divergence - * on a collective op = HANG. - * - * 2. tmem_dealloc was passing the SMEM pointer instead of tmem_base - * (the value WRITTEN to SMEM by tcgen05.alloc). - * - * 3. The TMEM fence was already fixed: tcgen05.wait::st.sync.aligned - * and tcgen05.wait::ld.sync.aligned (the old tcgen05.fence doesn't exist). - * - * ================================================================== - * DESIGN: Warp-collective TMEM with scalar computation - * ================================================================== - * - * Thread 0 computes the attention loop (QK, softmax, P@V) and writes - * intermediate values to SMEM buffers. Warp 0 (all 32 lanes) then - * performs TMEM load/modify/store collectively. This ensures: - * - Correctness: same math as the reference kernel - * - No warp divergence on collective ops - * - TMEM is used as the accumulator (the whole point of Phase 2) - * - * For a single-column case (hd<=4), we still allocate 32 TMEM columns - * (minimum for tcgen05.alloc) but only use the first ceil(HD/4). - * - * ================================================================== - * TMEM LAYOUT (for tcgen05.ld/st 16x256b.x1.b32) - * ================================================================== - * - * Each tcgen05.ld/st operates on one "column" of TMEM. A column holds - * 16 rows × 256 bits = 16 × 8 × 32-bit registers = 4 uint32_t per lane. - * But since this is warp-collective, the 4 uint32_t per lane across 32 lanes - * gives 128 uint32_t per column, covering 16 rows × 8 FP32 per row. - * - * For T=1 decode, we only care about row 0. Lane 0's 4 registers map to - * 4 FP32 values in row 0. So for HD head_dim values, we need - * ceil(HD/4) columns, accessed at column indices 0, 1, 2, ... - * - * Column address = tmem_base + column_index. - * tmem_base is the value written to SMEM by tcgen05.alloc (typically 0). + * TMEM lane mapping (verified on B200 via test_tmem_lane_mapping.cu): + * tcgen05.st/ld 16x256b.x1.b32 is warp-collective. Each lane i + * writes/reads positions i*4+0..i*4+3 within the column. + * 32 lanes × 4 FP32 = 128 FP32 per column. + * For row 0: lane 0 = positions 0-3, lane 1 = 4-7, ..., lane 31 = 124-127. */ #pragma once #include "fmha_common.cuh" @@ -70,48 +38,38 @@ fmha_decode_tmem( const bf16_t* vb = v + batch*bstride_kv; bf16_t* oh = o + batch*bstride_o + head*HD; - // TMEM layout (verified on B200 via test_tmem_lane_mapping.cu): - // tcgen05.st 16x256b.x1.b32: each lane i writes 4 FP32 to positions - // i*4+0..i*4+3 within the column. 32 lanes × 4 = 128 FP32 per column. - // For row 0, lane 0 = positions 0-3, lane 1 = 4-7, ..., lane 31 = 124-127. - // So HD values need ceil(HD/128) columns. - constexpr int TMEM_COLS_NEEDED = (HD + 127) / 128; // 1 for HD<=128, 2 for HD<=256 - // tcgen05.alloc requires power-of-2 columns, minimum 32 - constexpr int TMEM_N = TMEM_COLS_NEEDED <= 32 ? 32 : - (TMEM_COLS_NEEDED <= 64 ? 64 : - (TMEM_COLS_NEEDED <= 128 ? 128 : 256)); + // TMEM column count: 128 FP32 per column, need ceil(HD/128) + constexpr int TMEM_COLS_NEEDED = (HD + 127) / 128; + // tcgen05.alloc: power-of-2, minimum 32 + constexpr int TMEM_N = TMEM_COLS_NEEDED <= 32 ? 32 : 64; // SMEM layout: - // [0..3] tmem_base (written by tcgen05.alloc) - // [4..4+HD*4) sQ (HD floats) - // [4+HD*4..4+HD*4+4) sRowSums (1 float) - // [4+HD*4+8..) sPvBuf (4 floats for P@V intermediate) + // [0..3] tmem_base (written by tcgen05.alloc) + // [4..4+HD*4) sQ (HD floats) + // [4+HD*4..+8) sRowSums (1 float) + // [8+HD*4..] sO (HD floats — attention accumulator) extern __shared__ char sbuf[]; uint32_t* sTmemBase = (uint32_t*)sbuf; float* sQ = (float*)(sbuf + sizeof(uint32_t)); float* sRowSums = (float*)(sbuf + sizeof(uint32_t) + HD * sizeof(float)); - float* sPvBuf = (float*)(sbuf + sizeof(uint32_t) + (HD + 1) * sizeof(float)); + float* sO = (float*)(sbuf + sizeof(uint32_t) + (HD + 1) * sizeof(float)); - // Load Q to SMEM - for (int d = tid; d < HD; d += NTHREADS) sQ[d] = bf16_to_f32(qh[d]); + // Load Q to SMEM + init accumulator + for (int d = tid; d < HD; d += NTHREADS) { + sQ[d] = bf16_to_f32(qh[d]); + sO[d] = 0.0f; + } __syncthreads(); - // ================================================================ - // TMEM Allocation — warp-collective (all 32 lanes of warp 0) - // ================================================================ + // TMEM alloc — warp-collective (all 32 lanes of warp 0) if (wid == 0) { uint32_t smem_ptr = __cvta_generic_to_shared(sTmemBase); tmem_alloc(smem_ptr, TMEM_N); } __syncthreads(); - - // Read tmem_base from SMEM (written by alloc) uint32_t tmem_base = *sTmemBase; - if (tid == 0) printf("[tmem] base=%u, alloc_n=%d, cols_needed=%d\n", tmem_base, TMEM_N, TMEM_COLS_NEEDED); - // Initialize TMEM O to zero — warp-collective - // Use TMEM_COLS_NEEDED columns, each zeroed by all 32 lanes writing 0. - // TMEM_COLS_NEEDED defined above + // Zero TMEM columns — warp-collective if (wid == 0) { for (int col = 0; col < TMEM_COLS_NEEDED; col++) { tmem_store(tmem_base + col, 0, 0, 0, 0); @@ -119,73 +77,12 @@ fmha_decode_tmem( tmem_fence_store(); } __syncthreads(); - if (tid == 0) printf("[tmem] zeroed OK\n"); // ================================================================ - // Attention computation — thread 0 does the math, warp 0 does TMEM + // Attention computation in SMEM (same as reference kernel) // ================================================================ float row_max = -INFINITY, row_sum = 0.0f; - // Thread 0 computes the full attention loop - if (tid == 0) { - for (int c = 0; c < s_k; c++) { - // QK dot product - float s_val = 0.0f; - for (int d = 0; d < HD; d++) s_val += sQ[d] * bf16_to_f32(kb[c * HD + d]); - s_val *= scale; - - // D3: SWA mask - if (swa_len > 0 && c >= n_comp + swa_len) s_val = -INFINITY; - - float new_max = fmaxf(row_max, s_val); - if (new_max > row_max) { - float rescale = expf(row_max - new_max); - - // D1.5: Rescale O in TMEM - // Write rescale factor to SMEM so warp 0 can use it collectively - sPvBuf[0] = rescale; - sPvBuf[1] = 1.0f; // signal: rescale needed - - // Wake warp 0 to do the rescale - __threadfence_block(); // ensure SMEM writes visible - } - - float p_val = expf(s_val - row_max); - row_sum += p_val; - - // Compute p_val * V[:,c] and write to SMEM for warp-collective TMEM store - for (int d = 0; d < HD; d++) { - sPvBuf[d] = p_val * bf16_to_f32(vb[d * s_k + c]); - } - __threadfence_block(); - } - sRowSums[0] = row_sum; - } - - // The above approach still has the warp-collective problem: thread 0 - // writes to SMEM but warp 0 needs to read it and do TMEM ops. - // This requires a synchronization scheme that's complex. - // - // SIMPLER APPROACH: Do the attention loop in SMEM (like the reference - // kernel), then do a ONE-WAY TMEM write at the end. The TMEM is used - // for the correction epilogue, not for mid-loop accumulation. - // This proves the TMEM → regs → normalize → GMEM pipeline works. - // - // Once this works, we can move the accumulation into TMEM properly - // with a warp-cooperative design. - - // Actually, let me take the simplest correct approach that proves - // the TMEM pipeline: compute in SMEM (reference style), then - // write result to TMEM, read back from TMEM, normalize, output. - // This tests the one-way correction epilogue path. - - // Reset — compute attention in SMEM (same as reference) - for (int d = tid; d < HD; d += NTHREADS) sPvBuf[d] = 0.0f; - __syncthreads(); - - row_max = -INFINITY; - row_sum = 0.0f; - if (tid == 0) { for (int c = 0; c < s_k; c++) { float s_val = 0.0f; @@ -196,67 +93,45 @@ fmha_decode_tmem( float new_max = fmaxf(row_max, s_val); if (new_max > row_max) { float rescale = expf(row_max - new_max); - for (int d = 0; d < HD; d++) sPvBuf[d] *= rescale; + for (int d = 0; d < HD; d++) sO[d] *= rescale; row_sum *= rescale; row_max = new_max; } float p_val = expf(s_val - row_max); row_sum += p_val; - for (int d = 0; d < HD; d++) sPvBuf[d] += p_val * bf16_to_f32(vb[d * s_k + c]); + for (int d = 0; d < HD; d++) sO[d] += p_val * bf16_to_f32(vb[d * s_k + c]); } sRowSums[0] = row_sum; } __syncthreads(); - if (tid == 0) printf("[tmem] attention computed, row_sum=%f, sPvBuf[0]=%f, sPvBuf[32]=%f\n", sRowSums[0], sPvBuf[0], HD>32?sPvBuf[32]:0.0f); // ================================================================ // One-way Correction Epilogue: SMEM → TMEM → regs → normalize → GMEM - // - // This is the production pipeline that the MoE kernel uses: - // 1. Write accumulator to TMEM (warp-collective store) - // 2. Read from TMEM to registers (warp-collective load) - // 3. Normalize in registers (per-lane math) - // 4. Cast to BF16 and write to GMEM - // - // Steps 1-2 prove the TMEM round-trip works (one-way, not - // the broken Ld32x32bOp/St32x32bOp from CuTeDSL). // ================================================================ // Step 1: Write SMEM accumulator to TMEM (warp 0, warp-collective) - // - // TMEM lane mapping (verified on B200 via test_tmem_lane_mapping.cu): - // tcgen05.st 16x256b.x1.b32: each lane i writes 4 FP32 to positions - // i*4+0..i*4+3 within the column. 32 lanes × 4 = 128 FP32 per column. - // For row 0, lane 0 = positions 0-3, lane 1 = 4-7, ..., lane 31 = 124-127. - // - // So for HD<=128, a single column suffices (128 FP32 per col, row 0). - // Lane i writes sPvBuf[i*4+0..3] to column 0. - // Lanes with i*4 >= HD write zeros (don't-care, but must participate). - // - + // Lane i writes sO[i*4+0..3] to column (i*4) / 128. + // All 32 lanes must call tmem_store (warp-collective). if (wid == 0) { for (int col = 0; col < TMEM_COLS_NEEDED; col++) { - int base = col * 128; // FP32 offset for this column + int base = col * 128; int d0 = base + lane * 4 + 0; int d1 = base + lane * 4 + 1; int d2 = base + lane * 4 + 2; int d3 = base + lane * 4 + 3; - uint32_t u0 = (d0 < HD) ? f32_to_u32(sPvBuf[d0]) : 0; - uint32_t u1 = (d1 < HD) ? f32_to_u32(sPvBuf[d1]) : 0; - uint32_t u2 = (d2 < HD) ? f32_to_u32(sPvBuf[d2]) : 0; - uint32_t u3 = (d3 < HD) ? f32_to_u32(sPvBuf[d3]) : 0; + uint32_t u0 = (d0 < HD) ? f32_to_u32(sO[d0]) : 0; + uint32_t u1 = (d1 < HD) ? f32_to_u32(sO[d1]) : 0; + uint32_t u2 = (d2 < HD) ? f32_to_u32(sO[d2]) : 0; + uint32_t u3 = (d3 < HD) ? f32_to_u32(sO[d3]) : 0; tmem_store(tmem_base + col, u0, u1, u2, u3); } tmem_fence_store(); } __syncthreads(); - if (tid == 0) printf("[tmem] wrote to TMEM OK\n"); // Step 2: Read from TMEM to registers (warp 0, warp-collective) - // Same lane mapping: lane i reads positions i*4+0..3 from the column. - if (wid == 0) { float inv_sum = 1.0f / sRowSums[0]; @@ -270,9 +145,6 @@ fmha_decode_tmem( float r2 = u32_to_f32(u2) * inv_sum; float r3 = u32_to_f32(u3) * inv_sum; - // Debug: print first few values from lane 0 - if (lane == 0 && col == 0) printf("[tmem] read: r0=%f r1=%f r2=%f r3=%f inv_sum=%f\n", r0, r1, r2, r3, inv_sum); - // Step 4: Cast to BF16 and write to GMEM int base = col * 128; int d0 = base + lane * 4 + 0; @@ -293,10 +165,7 @@ fmha_decode_tmem( lse_out[batch * gridDim.y + head] = logf(row_sum) + row_max; } - // ================================================================ - // TMEM Deallocation — warp-collective (all 32 lanes of warp 0) - // Pass tmem_base (value written to SMEM by alloc), NOT the SMEM pointer - // ================================================================ + // TMEM dealloc — warp-collective (pass tmem_base, not SMEM pointer) if (wid == 0) { tmem_dealloc(tmem_base, TMEM_N); }