diff --git a/CURRENT_ISSUE.md b/CURRENT_ISSUE.md index a9257af3..aa3dc37f 100644 --- a/CURRENT_ISSUE.md +++ b/CURRENT_ISSUE.md @@ -1,6 +1,9 @@ # CURRENT_ISSUE.md — FMHA 6-Warp Specialization -## Status: Milestone 5 COMPLETE ✅ (multi-head grid launch with MHA/MQA/batch) +## Status: Milestone 4 IN PROGRESS (multi-row softmax for prefill T>1) +### Milestone 5 ✅ DONE — multi-head grid launch +### Milestone 4: T≤32 PASSING (cos 0.999996+), T>32 BLOCKED on TMEM row read +### CRITICAL BUG FIXED: Q/K SMEM canonical layout used full_d instead of local d (0..15) ### What works: - **6-warp kernel**: Warps 0-3 softmax/epilogue, Warp 4 MMA, Warp 5 data staging @@ -12,6 +15,7 @@ - **Batched**: blockIdx.z for batch dimension - **LSE output**: per-row LSE for multi-segment KV merge - **FmhaParams struct**: stride-based tensor addressing, future-proof for GQA +- **Multi-row softmax T≤32**: cos 0.999996+ with per-lane per-row softmax (no wmax/wsum) ### Architecture: ``` @@ -37,9 +41,15 @@ Warp 5 (tid 160-191): Data staging - Load next K/V while computing current QK - mbarrier producer-consumer sync between warp 5 and warp 4 - Depends on TMA loads (Milestone 2) -3. **Multi-row softmax** (Milestone 4): Process all 128 rows (prefill T>1) - - All 4 softmax warps process rows in parallel - - Warp w handles rows [w*32, (w+1)*32) ∩ [0, T) +3. **Multi-row softmax** (Milestone 4): Process all 128 rows (prefill T>1) 🚧 IN PROGRESS + - T≤32: WORKING — warp 0, lane l handles row l, 32x32b.x8 TMEM reads + - T>32: BLOCKED — 32x32b.x8 only reads rows 0-31 + - NEXT: Use 16x256b.x1 TMEM reads (reads all 128 rows per column) + - Each of 4 softmax warps handles rows [w*32, (w+1*32) ∩ [0, T) + - Per-lane row assignment in 16x256b: lane j gets rows j*4+0..3 + - No cross-warp reduction needed (disjoint row sets) + - KEY LESSON: Q/K SMEM canonical positions MUST use local d (0..15), NOT full_d + The UMMA descriptor always reads from sQ0/sK0 start, not offset 4. ~~**Multi-head launch** (Milestone 5): grid=(1, n_h, batch)~~ ✅ DONE 5. **Production integration** (Milestone 6): Hook into production.py diff --git a/CURRENT_ISSUE_FROM_OUTSIDE_CONSULTANT.md b/CURRENT_ISSUE_FROM_OUTSIDE_CONSULTANT.md index 108294c7..11650921 100644 --- a/CURRENT_ISSUE_FROM_OUTSIDE_CONSULTANT.md +++ b/CURRENT_ISSUE_FROM_OUTSIDE_CONSULTANT.md @@ -1,6 +1,6 @@ # ISSUE — Lightning Indexer FP4 dequant decodes E2M1 wrong -**Status:** FIXED ✅ — E2M1 LUT fix landed in both `dsv4/kernels/indexer/indexer_score_topk.cu` and `dsv4/kernels/cuda/indexer_score_topk.cu`. +**Status:** FIXED ✅ — E2M1 LUT fix landed in both `dsv4/kernels/indexer/indexer_score_topk.cu` and `dsv4/kernels/cuda/indexer_score_topk.cu`. Crossed off the list. **Severity:** Was HIGH. Corrupts top-k *selection*, which is the whole job of the indexer. **Scope:** `dsv4/kernels/indexer/indexer_score_topk.cu` and the duplicate `dsv4/kernels/cuda/indexer_score_topk.cu`. Does NOT touch FMHA, MoE, or the GEMM stack. diff --git a/PRIORITY2.md b/PRIORITY2.md new file mode 100644 index 00000000..28292577 --- /dev/null +++ b/PRIORITY2.md @@ -0,0 +1,175 @@ +# ISSUE — Priority 2: kill the Python KV merge; one-way in-kernel epilogue + online softmax + +**Status:** OPEN — structural shortcut on the FMHA decode hot path. +**Severity:** HIGH for "production grade." Correctness is fine (cos 0.999998); the +problem is this path can never be fast, and it blocks the entire multi-CTA / FP4 +fusion chain (D2/P4, NVFP4-1.2, NVFP4-2). +**Scope:** `dsv4/kernels/attention/production.py` (`_run_fmha_segmented`), +`dsv4/kernels/attention/fmha.py` (epilogue), the WIP `fmha_smem_acc.py`, and the +raw-CUDA `fmha_sm100_tc.cuh`. + +--- + +## TL;DR + +Multi-KV-tile attention is currently done by launching the single-tile CuTeDSL +kernel once per (segment × pv_tile), copying un-normalized O + LSE back to the +host-side allocator, and merging tiles with **eager PyTorch ops** — with a full +**`torch.cuda.synchronize()` inside the inner loop**. That is orchestration, not a +kernel. The fix is the standard FlashAttention shape: **loop KV tiles inside the +kernel with running max/sum (online softmax), and write the result once through a +one-way TMEM→regs→SMEM→GMEM epilogue** — the exact pattern the MoE GEMM already +runs correctly on Blackwell. + +This is the load-bearing example of doctrine rule 1: hitting a CuTeDSL wall is not +a license to fall back to Python. + +--- + +## The shortcut, measured (not estimated) + +`dsv4/kernels/attention/production.py`, `_run_fmha_segmented`: + +```python +for seg in range(n_segments): + ... + k_seg = torch.cat([k_seg, torch.zeros(...)]) # alloc + copy to pad + seg_o = torch.zeros(M, hd, ...) # per-seg allocs + for nt in range(n_pv_tiles): + c_tile = torch.zeros(M, pv_n_tile, 1, ...) # per-tile allocs + lse_tensor = torch.zeros(M, 1, 1, ...) + mQ = ct.from_dlpack(q_3d).mark_layout_dynamic(...) # descriptor rebuild + ... # every iteration + compiled(mQ, mK, mV, mC, stream, lse=mLSE, ...) # 1 launch + torch.cuda.synchronize() # <-- FULL DEVICE SYNC + seg_o[:, v_start:v_end] = c_tile[...].float() + # eager exp/log/div merge (≈5 launches) + o_accum = (e_old*o_accum + e_new*seg_o_norm) / e_sum +``` + +Per-call cost on the hot path, every iteration: +- 1 kernel launch **+ 1 full `cudaDeviceSynchronize`** (serializes the stream — + no overlap, no pipelining, no async) +- 6+ `torch.zeros` allocations (allocator churn) +- `torch.cat` to pad the partial tail segment (alloc + copy) +- `ct.from_dlpack(...).mark_layout_dynamic(...)` ×6 (host-side descriptor rebuild) +- ~5 eager merge ops per segment + +**Launch/sync count, V4-Pro CSA decode, single token, single layer:** +`top_k = 1024` → `n_segments = 1024/128 = 8`; `hd = 512` → `n_pv_tiles = 512/128 = 4`. +→ **32 kernel launches + 32 device syncs** per CSA layer, plus ~40 eager merge ops. +Across the interleaved CSA/HCA layers of a 61-layer model that is **~1–2k launches +and ~1–2k device syncs per decoded token.** At single-digit-µs launch+sync each, +that's milliseconds of pure overhead per token — a hard ceiling of tens of tok/s +before any real compute runs. At 1M context (KV-read bound, the entire point of +V4) this is the bottleneck, full stop. + +Head-packing (MQA → all 128 Q heads in M) is already correct and keeps this from +also multiplying by heads — good. The segment/tile loop and the per-iteration sync +are the problem. + +--- + +## The blocker + +`dsv4/kernels/attention/fmha.py` (~line 620) writes O via +`utils.gemm.sm100.epilogue_tma_store`, which reads O straight from TMEM and +TMA-stores it. Per the file header (point 2): it **cannot accept flat_divide-based +GMEM coordinates**, so you can't fold (batch, head, kv_tile) into a multi-CTA grid. +That single epilogue choice is why multi-tile has to be done host-side, and why +D2/P4 (multi-CTA) and FP4 output fusion are all stuck behind it. + +The TMEM round-trip ("D1.5 fundamentally broken") is a **CuTeDSL** atom-pairing +limitation, not hardware — CUTLASS C++ pairs Ld/St atoms that work. So the answer +is not "merge in Python," it's "use the one-way epilogue (no round-trip) or write +the raw CUDA kernel." + +--- + +## The pattern to port FROM (already working on Blackwell) + +`dsv4/kernels/gemm/dense.py` runs the correct one-way epilogue every MoE GEMM: +TMEM → regs → SMEM → GMEM, no round-trip, multi-CTA-safe, with a register slot for +fusion. Reuse these symbols (already imported into `fmha.py`): + +- `sm100_utils.compute_epilogue_tile_shape` (line ~313) — epi subtile +- `sm100_utils.make_smem_layout_epi` (line ~361) — SMEM staging layout +- `epilogue_tmem_copy_and_partition` / `epilogue_smem_copy_and_partition` + (imported at fmha.py lines 71–72, currently unused on the O path) +- `epilogue_op: cutlass.Constexpr = lambda x: x` (line ~405) — the hook where FP4 + pack lands later (NVFP4-1.2) + +This is not new infrastructure. It's wiring the FMHA O-store to the epilogue the +MoE kernel already proves out. + +--- + +## Design of the fix + +**1. Online softmax across KV tiles, in-kernel.** Replace the host-side log-sum-exp +merge with a running `(row_max, row_sum)` rescale inside the KV-tile loop — standard +FlashAttention-2. The kernel already emits un-normalized O + LSE + row_sums for a +single tile; lift the accumulation into the tile loop so one launch consumes all +selected KV. No host merge, no per-tile sync, no per-tile allocs. + +**2. One-way epilogue.** After the tile loop, O lives in TMEM. Copy +TMEM→regs→SMEM→GMEM via the MoE epilogue path. No `epilogue_tma_store`, no TMEM +round-trip. Normalize by the final `row_sum` in registers before the SMEM store +(or keep emitting LSE if external normalization is still wanted — but it shouldn't +be needed once the merge is in-kernel). + +**3. Multi-CTA grid.** With the one-way epilogue accepting flat_divide coords, +fold (batch, kv_head_group, [q_tile]) into the grid. This is D2/P4. `use_2cta_instrs` +applies for prefill/batched shapes (M ≥ 256); decode single-token stays 1-CTA. + +**4. Preallocated scratch.** Whatever staging remains is allocated once at warmup +and reused (cudagraph-safe), never `torch.zeros` on the hot path. + +### Two ramps — pick per doctrine, not per convenience + +- **Ramp A (preferred if CuTeDSL cooperates):** finish the one-way epilogue in + `fmha_smem_acc.py` against the MoE pattern. Status of that file is unclear ("many + commits, unclear if working") — first action is to determine whether it works, + with a print-and-diff, not by reading the diff log. +- **Ramp B (when CuTeDSL walls — the legitimate fallback):** `fmha_sm100_tc.cuh` + is already the right shape — `tcgen05.mma` SS for QK, TS for PV, TMEM accumulators, + UMMA descriptors, and it already carries `sRowMax/sRowSum` running state for + in-kernel online softmax. Finish THIS, not a Python merge. It is raw CUDA C++, + it is Blackwell-native, and it has no `epilogue_tma_store` constraint because you + control the store. + +--- + +## Test plan — measure, don't eyeball + +1. **Correctness parity:** new in-kernel path vs current Python-merge path on the + same inputs across the matrix that already passes — hd ∈ {64,128,256,512}, + n_segments ∈ {1,2,8}, with SWA mask / causal / sink / n_comp. Gate: cos ≥ 0.999998 + (must match the path it replaces, this is a refactor not a numerics change). +2. **Launch + sync count:** Nsight Systems (or a CUPTI launch counter) on one + decoded token, V4-Pro CSA layer. Record launches and `cudaDeviceSynchronize` + calls before/after. Target: per-layer launches from 32 → 1 (decode), syncs from + 32 → 0 on the hot path. +3. **Latency:** per-token decode latency at context 8k / 128k / 1M, before/after. + This is the number that says whether Priority 1's "does the overhead matter" + question is answered. (Spoiler from the launch math: it matters.) +4. **Unblock check:** confirm the new epilogue accepts a multi-CTA grid (D2 smoke + test, M≥256 prefill) and exposes a register slot for the FP4 pack (NVFP4-1.2 + stub: `epilogue_op` that amax+packs, behind a flag, off by default). + +--- + +## DOCTRINE REMINDER (this issue is the cautionary tale) + +1. **CuTeDSL/CUTLASS wall → raw CUDA C++, NOT Python.** The Python KV merge is what + rule 1 exists to prevent. `fmha_sm100_tc.cuh` is the correct fallback. A host-side + loop with a per-iteration `cudaDeviceSynchronize` is never the production answer. +2. **Raw CUDA ≠ scalar math.** The fallback kernel stays `tcgen05`/UMMA/TMEM/TMA with + warp-level softmax reductions. Do not let an agent "simplify" the PV into scalar + FMA to get it compiling. +3. **Print, don't guess.** Before touching `fmha_smem_acc.py`, print its actual + output and diff against the reference — do not infer its status from commit + messages. When wiring the epilogue, print epi tile shape, SMEM layout, TMEM + offsets, and the MMA instruction shape at construction, and code to those values. + The epilogue is exactly the kind of layout surface where a guess passes a toy + test and corrupts at real shapes. \ No newline at end of file diff --git a/dsv4/kernels/attention/fmha_6warp_multirow.cuh b/dsv4/kernels/attention/fmha_6warp_multirow.cuh index 0182a613..4b2bf684 100644 --- a/dsv4/kernels/attention/fmha_6warp_multirow.cuh +++ b/dsv4/kernels/attention/fmha_6warp_multirow.cuh @@ -1,15 +1,11 @@ /** * DSV4 FMHA — 6-warp specialized kernel, multi-row softmax (prefill T>1). * - * ================================================================== - * MULTI-ROW SOFTMAX (Milestone 4) - * ================================================================== - * For T ≤ 32: Uses 32x32b.x8 TMEM reads (same as decode). - * Each lane handles one row. Warp 0 processes rows 0..T-1. - * No cross-lane reduction needed (per-row max/sum in each lane). - * - * For T > 32: NOT YET IMPLEMENTED (will use 16x256b.x1). - * ================================================================== + * T <= 32: Warp 0 only, 32x32b.x8, lane l = row l. + * T > 32: 4 softmax warps, 16x256b.x1, two-pass online softmax. + * Warp w handles rows [w*32, min((w+1)*32, T)). + * Lane j reads rows j*4+0..3 per TMEM column. + * Active lanes for warp w: j in [w*8, w*8+8). */ #pragma once @@ -25,11 +21,9 @@ struct FmhaMultiRowParams { const bf16_t* __restrict__ v; bf16_t* __restrict__ o; float* __restrict__ lse; - int s_k, T; float scale; int head_dim; - int q_head_stride, q_batch_stride; int k_head_stride, k_batch_stride; int v_head_stride, v_batch_stride; @@ -53,88 +47,62 @@ fmha_6warp_multirow_kernel(FmhaMultiRowParams params) { const int tid = threadIdx.x; const int wid = tid / 32; const int lane = tid % 32; - - const bool is_softmax_warp = (wid == 0); + const bool is_softmax_warp = (wid < 4); const bool is_mma_warp = (wid == 4); const bool is_load_warp = (wid == 5); - const int T = params.T; const int s_k = params.s_k; const float scale = params.scale; + const bool full_tile = (T > 32); - // ================================================================== - // Per-head GMEM pointers - // ================================================================== - const bf16_t* __restrict__ q_head = params.q - + head_idx * params.q_head_stride - + batch_idx * params.q_batch_stride; - const bf16_t* __restrict__ k_head = params.k - + head_idx * params.k_head_stride - + batch_idx * params.k_batch_stride; - const bf16_t* __restrict__ v_head = params.v - + head_idx * params.v_head_stride - + batch_idx * params.v_batch_stride; - bf16_t* __restrict__ o_head = params.o - + head_idx * params.o_head_stride - + batch_idx * params.o_batch_stride; - float* __restrict__ lse_head = params.lse - ? params.lse + head_idx * params.lse_head_stride - + batch_idx * params.lse_batch_stride - : nullptr; + const bf16_t* __restrict__ q_head = params.q + head_idx * params.q_head_stride + batch_idx * params.q_batch_stride; + const bf16_t* __restrict__ k_head = params.k + head_idx * params.k_head_stride + batch_idx * params.k_batch_stride; + const bf16_t* __restrict__ v_head = params.v + head_idx * params.v_head_stride + batch_idx * params.v_batch_stride; + bf16_t* __restrict__ o_head = params.o + head_idx * params.o_head_stride + batch_idx * params.o_batch_stride; + float* __restrict__ lse_head = params.lse ? params.lse + head_idx * params.lse_head_stride + batch_idx * params.lse_batch_stride : nullptr; - // ================================================================ - // SMEM allocation — SAME LAYOUT as multihead kernel for T=1 compat - // ================================================================ + // SMEM extern __shared__ char sbuf[]; uint32_t* sTmemBase = (uint32_t*)sbuf; - float* sRowMax = (float*)(sbuf + 4); // [32] per-warp rows - float* sRowSum = sRowMax + 32; // [32] - bf16_t* sQ0 = (bf16_t*)(((uintptr_t)(sRowSum + 32) + 15) & ~(uintptr_t)15); + float* sRowMax = (float*)(sbuf + 4); // [MAX_ROWS] + float* sRowSum = sRowMax + MAX_ROWS; // [MAX_ROWS] + bf16_t* sQ0 = (bf16_t*)(((uintptr_t)(sRowSum + MAX_ROWS) + 15) & ~(uintptr_t)15); bf16_t* sK0 = sQ0 + TILE_SZ; bf16_t* sPk = (bf16_t*)(((uintptr_t)(sK0 + TILE_SZ) + 127) & ~(uintptr_t)127); bf16_t* sV = (bf16_t*)(((uintptr_t)(sPk + TILE_SZ) + 127) & ~(uintptr_t)127); - float* s_p_vals = (float*)(sV + V_SUB_SZ); // [MAX_ROWS][SK_TILE] + float* s_p_vals = (float*)(sV + V_SUB_SZ); // [MAX_ROWS][SK_TILE] - // ================================================================ - // TMEM allocation - // ================================================================ - if (is_mma_warp) { - uint32_t smem_ptr = __cvta_generic_to_shared(sTmemBase); - tmem_alloc(smem_ptr, TMEM_N); - } + // TMEM alloc + if (is_mma_warp) { uint32_t p = __cvta_generic_to_shared(sTmemBase); tmem_alloc(p, TMEM_N); } __syncthreads(); uint32_t tb = *sTmemBase; // ================================================================ - // QK GEMM loop + // QK GEMM // ================================================================ for (int kt = 0; kt < NKT_QK; kt++) { if (is_load_warp) { constexpr int CORES_MN = 128 / 8; for (int i = lane; i < TILE_SZ; i += 32) sQ0[i] = 0; - // Q loading: write to same canonical positions for each K-tile - // The UMMA descriptor always reads from sQ0 start for (int r = 0; r < T; r++) { for (int d = lane; d < MMA_K_BF16; d += 32) { - int full_d = kt * MMA_K_BF16 + d; // GMEM index - int ck = d / 8, lc = d % 8; // canonical position (same for all kt) + int full_d = kt * MMA_K_BF16 + d; + int ck = d / 8, lc = d % 8; int core_mn = r / 8, local_r = r % 8; - sQ0[ck * CORES_MN * 64 + core_mn * 64 + local_r * 8 + lc] = - q_head[r * HD + full_d]; + sQ0[ck * CORES_MN * 64 + core_mn * 64 + local_r * 8 + lc] = q_head[r * HD + full_d]; } } for (int i = lane; i < TILE_SZ; i += 32) sK0[i] = 0; for (int r = 0; r < s_k; r++) { for (int d = lane; d < MMA_K_BF16; d += 32) { - int full_d = kt * MMA_K_BF16 + d; // GMEM index - int ck = d / 8, lc = d % 8; // canonical position (same for all kt) + int full_d = kt * MMA_K_BF16 + d; + int ck = d / 8, lc = d % 8; int tmn = r / 8, lr = r % 8; sK0[ck * CORES_MN * 64 + tmn * 64 + lr * 8 + lc] = k_head[r * HD + full_d]; } } } __syncthreads(); - if (is_mma_warp) { uint32_t idesc = make_idesc(128, 128); uint64_t dq = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sQ0), 128); @@ -146,56 +114,90 @@ fmha_6warp_multirow_kernel(FmhaMultiRowParams params) { } // ================================================================ - // Softmax (32x32b.x8, per-lane per-row) + // SOFTMAX // ================================================================ - // For T ≤ 32: lane l handles row l. Lanes l >= T have no data. - // No wmax/wsum — each lane computes its own row independently. - // ================================================================ - if (is_softmax_warp) { - float s_vals[SK_TILE], row_max = -INFINITY; - - // Read S from TMEM - for (int n = 0; n < SK_TILE / 8; n++) { - float tmp[8]; - asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];" - : "=f"(tmp[0]),"=f"(tmp[1]),"=f"(tmp[2]),"=f"(tmp[3]), - "=f"(tmp[4]),"=f"(tmp[5]),"=f"(tmp[6]),"=f"(tmp[7]) - : "r"(tb + n * 8)); - asm volatile("tcgen05.wait::ld.sync.aligned;"); - if (lane < T) { - for (int c = 0; c < 8; c++) { - s_vals[n * 8 + c] = tmp[c] * scale; - row_max = fmaxf(row_max, s_vals[n * 8 + c]); + if (!full_tile) { + // T <= 32: warp 0, 32x32b.x8, lane l = row l + if (wid == 0) { + float s_vals[SK_TILE], row_max = -INFINITY; + for (int n = 0; n < SK_TILE / 8; n++) { + float tmp[8]; + asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];" + : "=f"(tmp[0]),"=f"(tmp[1]),"=f"(tmp[2]),"=f"(tmp[3]), + "=f"(tmp[4]),"=f"(tmp[5]),"=f"(tmp[6]),"=f"(tmp[7]) + : "r"(tb + n * 8)); + asm volatile("tcgen05.wait::ld.sync.aligned;"); + if (lane < T) { + for (int c = 0; c < 8; c++) { + s_vals[n * 8 + c] = tmp[c] * scale; + row_max = fmaxf(row_max, s_vals[n * 8 + c]); + } } } - } - - // Store per-row max - if (lane < T) sRowMax[lane] = row_max; - - float row_sum = 0.0f; - if (lane < T) { - for (int j = 0; j < SK_TILE; j++) { - s_vals[j] = expf(s_vals[j] - row_max); - row_sum += s_vals[j]; + if (lane < T) sRowMax[lane] = row_max; + float row_sum = 0.0f; + if (lane < T) { + for (int j = 0; j < SK_TILE; j++) { s_vals[j] = expf(s_vals[j] - row_max); row_sum += s_vals[j]; } + } + if (lane < T) sRowSum[lane] = row_sum; + if (lane < T) { + float inv = 1.0f / row_sum; + for (int j = 0; j < SK_TILE; j++) s_p_vals[lane * SK_TILE + j] = s_vals[j] * inv; } } + } else { + // T > 32: 4 warps, 16x256b.x1, two-pass softmax + if (is_softmax_warp) { + const int warp_row_start = wid * 32; + const int warp_row_end = min((wid + 1) * 32, T); + const int lane_row_base = lane * 4; + const bool active = (lane_row_base >= warp_row_start) && (lane_row_base < warp_row_end); + const int vrows = active ? min(4, warp_row_end - lane_row_base) : 0; - // Store per-row sum - if (lane < T) sRowSum[lane] = row_sum; + float rmax[4], rsum[4]; + for (int r = 0; r < 4; r++) rmax[r] = -INFINITY; - // Normalize and write P to s_p_vals - if (lane < T) { - float inv_sum = 1.0f / row_sum; - for (int j = 0; j < SK_TILE; j++) { - s_p_vals[lane * SK_TILE + j] = s_vals[j] * inv_sum; + // Pass 1: row_max + for (int col = 0; col < SK_TILE; col++) { + float v[4]; + asm volatile("tcgen05.ld.sync.aligned.16x256b.x1.b32 {%0,%1,%2,%3},[%4,%5];" + : "=f"(v[0]), "=f"(v[1]), "=f"(v[2]), "=f"(v[3]) : "r"(tb), "r"(col)); + asm volatile("tcgen05.wait::ld.sync.aligned;"); + if (active) for (int r = 0; r < vrows; r++) rmax[r] = fmaxf(rmax[r], v[r] * scale); + } + + // Pass 2: exp, sum, write P + for (int r = 0; r < 4; r++) rsum[r] = 0.0f; + for (int col = 0; col < SK_TILE; col++) { + float v[4]; + asm volatile("tcgen05.ld.sync.aligned.16x256b.x1.b32 {%0,%1,%2,%3},[%4,%5];" + : "=f"(v[0]), "=f"(v[1]), "=f"(v[2]), "=f"(v[3]) : "r"(tb), "r"(col)); + asm volatile("tcgen05.wait::ld.sync.aligned;"); + if (active) { + for (int r = 0; r < vrows; r++) { + int gr = lane_row_base + r; + float p = expf(v[r] * scale - rmax[r]); + rsum[r] += p; + s_p_vals[gr * SK_TILE + col] = p; + } + } + } + // Normalize and store max/sum + if (active) { + for (int r = 0; r < vrows; r++) { + int gr = lane_row_base + r; + sRowMax[gr] = rmax[r]; + sRowSum[gr] = rsum[r]; + float inv = 1.0f / rsum[r]; + for (int j = 0; j < SK_TILE; j++) s_p_vals[gr * SK_TILE + j] *= inv; + } } } } __syncthreads(); // ================================================================ - // PV GEMM loop + // PV GEMM // ================================================================ for (int n = 0; n < N_NSUB; n++) { int d_base = n * 16; @@ -205,32 +207,27 @@ fmha_6warp_multirow_kernel(FmhaMultiRowParams params) { for (int i = lane; i < TILE_SZ; i += 32) sPk[i] = 0; for (int r = 0; r < T; r++) { for (int c = lane; c < MMA_K_BF16; c += 32) { - int global_col = kt * MMA_K_BF16 + c; - float pval = s_p_vals[r * SK_TILE + global_col]; - int core_mn = r / 8, local_r = r % 8; - int core_k = c / 8, local_c = c % 8; - sPk[core_k * CORES_MN * 64 + core_mn * 64 + local_r * 8 + local_c] = - f32_to_bf16(pval); + int gc = kt * MMA_K_BF16 + c; + float pv = s_p_vals[r * SK_TILE + gc]; + int cmn = r / 8, lr = r % 8, ck = c / 8, lc = c % 8; + sPk[ck * CORES_MN * 64 + cmn * 64 + lr * 8 + lc] = f32_to_bf16(pv); } } for (int i = lane; i < V_SUB_SZ; i += 32) sV[i] = 0; for (int dd = lane; dd < 16; dd += 32) { for (int lr = 0; lr < MMA_K_BF16; lr++) { int r = kt * MMA_K_BF16 + lr; - int g_mn = dd / 8, g_k = lr / 8; - int llr = dd % 8, lc = lr % 8; - sV[g_k * 2 * 64 + g_mn * 64 + llr * 8 + lc] = - v_head[(d_base + dd) * s_k + r]; + int gmn = dd / 8, gk = lr / 8, llr = dd % 8, lc = lr % 8; + sV[gk * 2 * 64 + gmn * 64 + llr * 8 + lc] = v_head[(d_base + dd) * s_k + r]; } } } __syncthreads(); - if (is_mma_warp) { - uint32_t idesc_pv16 = make_idesc(128, 16); + uint32_t idesc16 = make_idesc(128, 16); uint64_t dp = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sPk), 128); uint64_t dv = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sV), 16); - if (tid == 128) umma_ss_f16(tb + n * 16, dp, dv, idesc_pv16, kt > 0); + if (tid == 128) umma_ss_f16(tb + n * 16, dp, dv, idesc16, kt > 0); asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory"); } __syncthreads(); @@ -238,41 +235,67 @@ fmha_6warp_multirow_kernel(FmhaMultiRowParams params) { } // ================================================================ - // Epilogue (32x32b.x8, same as multihead kernel) + // EPILOGUE // ================================================================ - if (is_softmax_warp) { - float row_max = (lane < T) ? sRowMax[lane] : 0.0f; - float row_sum = (lane < T) ? sRowSum[lane] : 1.0f; - float o_vals[HD]; - - for (int n = 0; n < HD / 8; n++) { - float tmp[8]; - asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];" - : "=f"(tmp[0]),"=f"(tmp[1]),"=f"(tmp[2]),"=f"(tmp[3]), - "=f"(tmp[4]),"=f"(tmp[5]),"=f"(tmp[6]),"=f"(tmp[7]) - : "r"(tb + n * 8)); - asm volatile("tcgen05.wait::ld.sync.aligned;"); + if (!full_tile) { + // T <= 32: warp 0, 32x32b.x8 + if (wid == 0) { + float rm = (lane < T) ? sRowMax[lane] : 0.0f; + float rs = (lane < T) ? sRowSum[lane] : 1.0f; + float o_vals[HD]; + for (int n = 0; n < HD / 8; n++) { + float tmp[8]; + asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];" + : "=f"(tmp[0]),"=f"(tmp[1]),"=f"(tmp[2]),"=f"(tmp[3]), + "=f"(tmp[4]),"=f"(tmp[5]),"=f"(tmp[6]),"=f"(tmp[7]) + : "r"(tb + n * 8)); + asm volatile("tcgen05.wait::ld.sync.aligned;"); + if (lane < T) for (int c = 0; c < 8; c++) o_vals[n * 8 + c] = tmp[c]; + } if (lane < T) { - for (int c = 0; c < 8; c++) o_vals[n * 8 + c] = tmp[c]; + float inv = 1.0f / rs; + for (int d = 0; d < HD; d++) o_head[lane * HD + d] = f32_to_bf16(o_vals[d] * inv); + if (lse_head) lse_head[lane] = logf(rs) + rm; } } + } else { + // T > 32: 4 warps, 16x256b.x1 + if (is_softmax_warp) { + const int warp_row_start = wid * 32; + const int warp_row_end = min((wid + 1) * 32, T); + const int lane_row_base = lane * 4; + const bool active = (lane_row_base >= warp_row_start) && (lane_row_base < warp_row_end); + const int vrows = active ? min(4, warp_row_end - lane_row_base) : 0; - if (lane < T) { - float inv_row_sum = 1.0f / row_sum; + // Read O from TMEM one column at a time for (int d = 0; d < HD; d++) { - o_head[lane * HD + d] = f32_to_bf16(o_vals[d] * inv_row_sum); + float v[4]; + asm volatile("tcgen05.ld.sync.aligned.16x256b.x1.b32 {%0,%1,%2,%3},[%4,%5];" + : "=f"(v[0]), "=f"(v[1]), "=f"(v[2]), "=f"(v[3]) : "r"(tb), "r"(d)); + asm volatile("tcgen05.wait::ld.sync.aligned;"); + if (active) { + for (int r = 0; r < vrows; r++) { + int gr = lane_row_base + r; + if (gr < T) { + float inv = 1.0f / sRowSum[gr]; + o_head[gr * HD + d] = f32_to_bf16(v[r] * inv); + } + } + } } - if (lse_head) { - lse_head[lane] = logf(row_sum) + row_max; + // LSE + if (active) { + for (int r = 0; r < vrows; r++) { + int gr = lane_row_base + r; + if (gr < T && lse_head) lse_head[gr] = logf(sRowSum[gr]) + sRowMax[gr]; + } } } } __syncthreads(); // TMEM dealloc - if (is_mma_warp) { - tmem_dealloc(tb, TMEM_N); - } + if (is_mma_warp) tmem_dealloc(tb, TMEM_N); } } // namespace dsv4::kernels::attention