stuff
This commit is contained in:
@@ -1,6 +1,9 @@
|
||||
# CURRENT_ISSUE.md — FMHA 6-Warp Specialization
|
||||
|
||||
## Status: Milestone 5 COMPLETE ✅ (multi-head grid launch with MHA/MQA/batch)
|
||||
## Status: Milestone 4 IN PROGRESS (multi-row softmax for prefill T>1)
|
||||
### Milestone 5 ✅ DONE — multi-head grid launch
|
||||
### Milestone 4: T≤32 PASSING (cos 0.999996+), T>32 BLOCKED on TMEM row read
|
||||
### CRITICAL BUG FIXED: Q/K SMEM canonical layout used full_d instead of local d (0..15)
|
||||
|
||||
### What works:
|
||||
- **6-warp kernel**: Warps 0-3 softmax/epilogue, Warp 4 MMA, Warp 5 data staging
|
||||
@@ -12,6 +15,7 @@
|
||||
- **Batched**: blockIdx.z for batch dimension
|
||||
- **LSE output**: per-row LSE for multi-segment KV merge
|
||||
- **FmhaParams struct**: stride-based tensor addressing, future-proof for GQA
|
||||
- **Multi-row softmax T≤32**: cos 0.999996+ with per-lane per-row softmax (no wmax/wsum)
|
||||
|
||||
### Architecture:
|
||||
```
|
||||
@@ -37,9 +41,15 @@ Warp 5 (tid 160-191): Data staging
|
||||
- Load next K/V while computing current QK
|
||||
- mbarrier producer-consumer sync between warp 5 and warp 4
|
||||
- Depends on TMA loads (Milestone 2)
|
||||
3. **Multi-row softmax** (Milestone 4): Process all 128 rows (prefill T>1)
|
||||
- All 4 softmax warps process rows in parallel
|
||||
- Warp w handles rows [w*32, (w+1)*32) ∩ [0, T)
|
||||
3. **Multi-row softmax** (Milestone 4): Process all 128 rows (prefill T>1) 🚧 IN PROGRESS
|
||||
- T≤32: WORKING — warp 0, lane l handles row l, 32x32b.x8 TMEM reads
|
||||
- T>32: BLOCKED — 32x32b.x8 only reads rows 0-31
|
||||
- NEXT: Use 16x256b.x1 TMEM reads (reads all 128 rows per column)
|
||||
- Each of 4 softmax warps handles rows [w*32, (w+1*32) ∩ [0, T)
|
||||
- Per-lane row assignment in 16x256b: lane j gets rows j*4+0..3
|
||||
- No cross-warp reduction needed (disjoint row sets)
|
||||
- KEY LESSON: Q/K SMEM canonical positions MUST use local d (0..15), NOT full_d
|
||||
The UMMA descriptor always reads from sQ0/sK0 start, not offset
|
||||
4. ~~**Multi-head launch** (Milestone 5): grid=(1, n_h, batch)~~ ✅ DONE
|
||||
5. **Production integration** (Milestone 6): Hook into production.py
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# ISSUE — Lightning Indexer FP4 dequant decodes E2M1 wrong
|
||||
|
||||
**Status:** FIXED ✅ — E2M1 LUT fix landed in both `dsv4/kernels/indexer/indexer_score_topk.cu` and `dsv4/kernels/cuda/indexer_score_topk.cu`.
|
||||
**Status:** FIXED ✅ — E2M1 LUT fix landed in both `dsv4/kernels/indexer/indexer_score_topk.cu` and `dsv4/kernels/cuda/indexer_score_topk.cu`. Crossed off the list.
|
||||
**Severity:** Was HIGH. Corrupts top-k *selection*, which is the whole job of the indexer.
|
||||
**Scope:** `dsv4/kernels/indexer/indexer_score_topk.cu` and the duplicate
|
||||
`dsv4/kernels/cuda/indexer_score_topk.cu`. Does NOT touch FMHA, MoE, or the GEMM stack.
|
||||
|
||||
175
PRIORITY2.md
Normal file
175
PRIORITY2.md
Normal file
@@ -0,0 +1,175 @@
|
||||
# ISSUE — Priority 2: kill the Python KV merge; one-way in-kernel epilogue + online softmax
|
||||
|
||||
**Status:** OPEN — structural shortcut on the FMHA decode hot path.
|
||||
**Severity:** HIGH for "production grade." Correctness is fine (cos 0.999998); the
|
||||
problem is this path can never be fast, and it blocks the entire multi-CTA / FP4
|
||||
fusion chain (D2/P4, NVFP4-1.2, NVFP4-2).
|
||||
**Scope:** `dsv4/kernels/attention/production.py` (`_run_fmha_segmented`),
|
||||
`dsv4/kernels/attention/fmha.py` (epilogue), the WIP `fmha_smem_acc.py`, and the
|
||||
raw-CUDA `fmha_sm100_tc.cuh`.
|
||||
|
||||
---
|
||||
|
||||
## TL;DR
|
||||
|
||||
Multi-KV-tile attention is currently done by launching the single-tile CuTeDSL
|
||||
kernel once per (segment × pv_tile), copying un-normalized O + LSE back to the
|
||||
host-side allocator, and merging tiles with **eager PyTorch ops** — with a full
|
||||
**`torch.cuda.synchronize()` inside the inner loop**. That is orchestration, not a
|
||||
kernel. The fix is the standard FlashAttention shape: **loop KV tiles inside the
|
||||
kernel with running max/sum (online softmax), and write the result once through a
|
||||
one-way TMEM→regs→SMEM→GMEM epilogue** — the exact pattern the MoE GEMM already
|
||||
runs correctly on Blackwell.
|
||||
|
||||
This is the load-bearing example of doctrine rule 1: hitting a CuTeDSL wall is not
|
||||
a license to fall back to Python.
|
||||
|
||||
---
|
||||
|
||||
## The shortcut, measured (not estimated)
|
||||
|
||||
`dsv4/kernels/attention/production.py`, `_run_fmha_segmented`:
|
||||
|
||||
```python
|
||||
for seg in range(n_segments):
|
||||
...
|
||||
k_seg = torch.cat([k_seg, torch.zeros(...)]) # alloc + copy to pad
|
||||
seg_o = torch.zeros(M, hd, ...) # per-seg allocs
|
||||
for nt in range(n_pv_tiles):
|
||||
c_tile = torch.zeros(M, pv_n_tile, 1, ...) # per-tile allocs
|
||||
lse_tensor = torch.zeros(M, 1, 1, ...)
|
||||
mQ = ct.from_dlpack(q_3d).mark_layout_dynamic(...) # descriptor rebuild
|
||||
... # every iteration
|
||||
compiled(mQ, mK, mV, mC, stream, lse=mLSE, ...) # 1 launch
|
||||
torch.cuda.synchronize() # <-- FULL DEVICE SYNC
|
||||
seg_o[:, v_start:v_end] = c_tile[...].float()
|
||||
# eager exp/log/div merge (≈5 launches)
|
||||
o_accum = (e_old*o_accum + e_new*seg_o_norm) / e_sum
|
||||
```
|
||||
|
||||
Per-call cost on the hot path, every iteration:
|
||||
- 1 kernel launch **+ 1 full `cudaDeviceSynchronize`** (serializes the stream —
|
||||
no overlap, no pipelining, no async)
|
||||
- 6+ `torch.zeros` allocations (allocator churn)
|
||||
- `torch.cat` to pad the partial tail segment (alloc + copy)
|
||||
- `ct.from_dlpack(...).mark_layout_dynamic(...)` ×6 (host-side descriptor rebuild)
|
||||
- ~5 eager merge ops per segment
|
||||
|
||||
**Launch/sync count, V4-Pro CSA decode, single token, single layer:**
|
||||
`top_k = 1024` → `n_segments = 1024/128 = 8`; `hd = 512` → `n_pv_tiles = 512/128 = 4`.
|
||||
→ **32 kernel launches + 32 device syncs** per CSA layer, plus ~40 eager merge ops.
|
||||
Across the interleaved CSA/HCA layers of a 61-layer model that is **~1–2k launches
|
||||
and ~1–2k device syncs per decoded token.** At single-digit-µs launch+sync each,
|
||||
that's milliseconds of pure overhead per token — a hard ceiling of tens of tok/s
|
||||
before any real compute runs. At 1M context (KV-read bound, the entire point of
|
||||
V4) this is the bottleneck, full stop.
|
||||
|
||||
Head-packing (MQA → all 128 Q heads in M) is already correct and keeps this from
|
||||
also multiplying by heads — good. The segment/tile loop and the per-iteration sync
|
||||
are the problem.
|
||||
|
||||
---
|
||||
|
||||
## The blocker
|
||||
|
||||
`dsv4/kernels/attention/fmha.py` (~line 620) writes O via
|
||||
`utils.gemm.sm100.epilogue_tma_store`, which reads O straight from TMEM and
|
||||
TMA-stores it. Per the file header (point 2): it **cannot accept flat_divide-based
|
||||
GMEM coordinates**, so you can't fold (batch, head, kv_tile) into a multi-CTA grid.
|
||||
That single epilogue choice is why multi-tile has to be done host-side, and why
|
||||
D2/P4 (multi-CTA) and FP4 output fusion are all stuck behind it.
|
||||
|
||||
The TMEM round-trip ("D1.5 fundamentally broken") is a **CuTeDSL** atom-pairing
|
||||
limitation, not hardware — CUTLASS C++ pairs Ld/St atoms that work. So the answer
|
||||
is not "merge in Python," it's "use the one-way epilogue (no round-trip) or write
|
||||
the raw CUDA kernel."
|
||||
|
||||
---
|
||||
|
||||
## The pattern to port FROM (already working on Blackwell)
|
||||
|
||||
`dsv4/kernels/gemm/dense.py` runs the correct one-way epilogue every MoE GEMM:
|
||||
TMEM → regs → SMEM → GMEM, no round-trip, multi-CTA-safe, with a register slot for
|
||||
fusion. Reuse these symbols (already imported into `fmha.py`):
|
||||
|
||||
- `sm100_utils.compute_epilogue_tile_shape` (line ~313) — epi subtile
|
||||
- `sm100_utils.make_smem_layout_epi` (line ~361) — SMEM staging layout
|
||||
- `epilogue_tmem_copy_and_partition` / `epilogue_smem_copy_and_partition`
|
||||
(imported at fmha.py lines 71–72, currently unused on the O path)
|
||||
- `epilogue_op: cutlass.Constexpr = lambda x: x` (line ~405) — the hook where FP4
|
||||
pack lands later (NVFP4-1.2)
|
||||
|
||||
This is not new infrastructure. It's wiring the FMHA O-store to the epilogue the
|
||||
MoE kernel already proves out.
|
||||
|
||||
---
|
||||
|
||||
## Design of the fix
|
||||
|
||||
**1. Online softmax across KV tiles, in-kernel.** Replace the host-side log-sum-exp
|
||||
merge with a running `(row_max, row_sum)` rescale inside the KV-tile loop — standard
|
||||
FlashAttention-2. The kernel already emits un-normalized O + LSE + row_sums for a
|
||||
single tile; lift the accumulation into the tile loop so one launch consumes all
|
||||
selected KV. No host merge, no per-tile sync, no per-tile allocs.
|
||||
|
||||
**2. One-way epilogue.** After the tile loop, O lives in TMEM. Copy
|
||||
TMEM→regs→SMEM→GMEM via the MoE epilogue path. No `epilogue_tma_store`, no TMEM
|
||||
round-trip. Normalize by the final `row_sum` in registers before the SMEM store
|
||||
(or keep emitting LSE if external normalization is still wanted — but it shouldn't
|
||||
be needed once the merge is in-kernel).
|
||||
|
||||
**3. Multi-CTA grid.** With the one-way epilogue accepting flat_divide coords,
|
||||
fold (batch, kv_head_group, [q_tile]) into the grid. This is D2/P4. `use_2cta_instrs`
|
||||
applies for prefill/batched shapes (M ≥ 256); decode single-token stays 1-CTA.
|
||||
|
||||
**4. Preallocated scratch.** Whatever staging remains is allocated once at warmup
|
||||
and reused (cudagraph-safe), never `torch.zeros` on the hot path.
|
||||
|
||||
### Two ramps — pick per doctrine, not per convenience
|
||||
|
||||
- **Ramp A (preferred if CuTeDSL cooperates):** finish the one-way epilogue in
|
||||
`fmha_smem_acc.py` against the MoE pattern. Status of that file is unclear ("many
|
||||
commits, unclear if working") — first action is to determine whether it works,
|
||||
with a print-and-diff, not by reading the diff log.
|
||||
- **Ramp B (when CuTeDSL walls — the legitimate fallback):** `fmha_sm100_tc.cuh`
|
||||
is already the right shape — `tcgen05.mma` SS for QK, TS for PV, TMEM accumulators,
|
||||
UMMA descriptors, and it already carries `sRowMax/sRowSum` running state for
|
||||
in-kernel online softmax. Finish THIS, not a Python merge. It is raw CUDA C++,
|
||||
it is Blackwell-native, and it has no `epilogue_tma_store` constraint because you
|
||||
control the store.
|
||||
|
||||
---
|
||||
|
||||
## Test plan — measure, don't eyeball
|
||||
|
||||
1. **Correctness parity:** new in-kernel path vs current Python-merge path on the
|
||||
same inputs across the matrix that already passes — hd ∈ {64,128,256,512},
|
||||
n_segments ∈ {1,2,8}, with SWA mask / causal / sink / n_comp. Gate: cos ≥ 0.999998
|
||||
(must match the path it replaces, this is a refactor not a numerics change).
|
||||
2. **Launch + sync count:** Nsight Systems (or a CUPTI launch counter) on one
|
||||
decoded token, V4-Pro CSA layer. Record launches and `cudaDeviceSynchronize`
|
||||
calls before/after. Target: per-layer launches from 32 → 1 (decode), syncs from
|
||||
32 → 0 on the hot path.
|
||||
3. **Latency:** per-token decode latency at context 8k / 128k / 1M, before/after.
|
||||
This is the number that says whether Priority 1's "does the overhead matter"
|
||||
question is answered. (Spoiler from the launch math: it matters.)
|
||||
4. **Unblock check:** confirm the new epilogue accepts a multi-CTA grid (D2 smoke
|
||||
test, M≥256 prefill) and exposes a register slot for the FP4 pack (NVFP4-1.2
|
||||
stub: `epilogue_op` that amax+packs, behind a flag, off by default).
|
||||
|
||||
---
|
||||
|
||||
## DOCTRINE REMINDER (this issue is the cautionary tale)
|
||||
|
||||
1. **CuTeDSL/CUTLASS wall → raw CUDA C++, NOT Python.** The Python KV merge is what
|
||||
rule 1 exists to prevent. `fmha_sm100_tc.cuh` is the correct fallback. A host-side
|
||||
loop with a per-iteration `cudaDeviceSynchronize` is never the production answer.
|
||||
2. **Raw CUDA ≠ scalar math.** The fallback kernel stays `tcgen05`/UMMA/TMEM/TMA with
|
||||
warp-level softmax reductions. Do not let an agent "simplify" the PV into scalar
|
||||
FMA to get it compiling.
|
||||
3. **Print, don't guess.** Before touching `fmha_smem_acc.py`, print its actual
|
||||
output and diff against the reference — do not infer its status from commit
|
||||
messages. When wiring the epilogue, print epi tile shape, SMEM layout, TMEM
|
||||
offsets, and the MMA instruction shape at construction, and code to those values.
|
||||
The epilogue is exactly the kind of layout surface where a guess passes a toy
|
||||
test and corrupts at real shapes.
|
||||
@@ -1,15 +1,11 @@
|
||||
/**
|
||||
* DSV4 FMHA — 6-warp specialized kernel, multi-row softmax (prefill T>1).
|
||||
*
|
||||
* ==================================================================
|
||||
* MULTI-ROW SOFTMAX (Milestone 4)
|
||||
* ==================================================================
|
||||
* For T ≤ 32: Uses 32x32b.x8 TMEM reads (same as decode).
|
||||
* Each lane handles one row. Warp 0 processes rows 0..T-1.
|
||||
* No cross-lane reduction needed (per-row max/sum in each lane).
|
||||
*
|
||||
* For T > 32: NOT YET IMPLEMENTED (will use 16x256b.x1).
|
||||
* ==================================================================
|
||||
* T <= 32: Warp 0 only, 32x32b.x8, lane l = row l.
|
||||
* T > 32: 4 softmax warps, 16x256b.x1, two-pass online softmax.
|
||||
* Warp w handles rows [w*32, min((w+1)*32, T)).
|
||||
* Lane j reads rows j*4+0..3 per TMEM column.
|
||||
* Active lanes for warp w: j in [w*8, w*8+8).
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
@@ -25,11 +21,9 @@ struct FmhaMultiRowParams {
|
||||
const bf16_t* __restrict__ v;
|
||||
bf16_t* __restrict__ o;
|
||||
float* __restrict__ lse;
|
||||
|
||||
int s_k, T;
|
||||
float scale;
|
||||
int head_dim;
|
||||
|
||||
int q_head_stride, q_batch_stride;
|
||||
int k_head_stride, k_batch_stride;
|
||||
int v_head_stride, v_batch_stride;
|
||||
@@ -53,88 +47,62 @@ fmha_6warp_multirow_kernel(FmhaMultiRowParams params) {
|
||||
const int tid = threadIdx.x;
|
||||
const int wid = tid / 32;
|
||||
const int lane = tid % 32;
|
||||
|
||||
const bool is_softmax_warp = (wid == 0);
|
||||
const bool is_softmax_warp = (wid < 4);
|
||||
const bool is_mma_warp = (wid == 4);
|
||||
const bool is_load_warp = (wid == 5);
|
||||
|
||||
const int T = params.T;
|
||||
const int s_k = params.s_k;
|
||||
const float scale = params.scale;
|
||||
const bool full_tile = (T > 32);
|
||||
|
||||
// ==================================================================
|
||||
// Per-head GMEM pointers
|
||||
// ==================================================================
|
||||
const bf16_t* __restrict__ q_head = params.q
|
||||
+ head_idx * params.q_head_stride
|
||||
+ batch_idx * params.q_batch_stride;
|
||||
const bf16_t* __restrict__ k_head = params.k
|
||||
+ head_idx * params.k_head_stride
|
||||
+ batch_idx * params.k_batch_stride;
|
||||
const bf16_t* __restrict__ v_head = params.v
|
||||
+ head_idx * params.v_head_stride
|
||||
+ batch_idx * params.v_batch_stride;
|
||||
bf16_t* __restrict__ o_head = params.o
|
||||
+ head_idx * params.o_head_stride
|
||||
+ batch_idx * params.o_batch_stride;
|
||||
float* __restrict__ lse_head = params.lse
|
||||
? params.lse + head_idx * params.lse_head_stride
|
||||
+ batch_idx * params.lse_batch_stride
|
||||
: nullptr;
|
||||
const bf16_t* __restrict__ q_head = params.q + head_idx * params.q_head_stride + batch_idx * params.q_batch_stride;
|
||||
const bf16_t* __restrict__ k_head = params.k + head_idx * params.k_head_stride + batch_idx * params.k_batch_stride;
|
||||
const bf16_t* __restrict__ v_head = params.v + head_idx * params.v_head_stride + batch_idx * params.v_batch_stride;
|
||||
bf16_t* __restrict__ o_head = params.o + head_idx * params.o_head_stride + batch_idx * params.o_batch_stride;
|
||||
float* __restrict__ lse_head = params.lse ? params.lse + head_idx * params.lse_head_stride + batch_idx * params.lse_batch_stride : nullptr;
|
||||
|
||||
// ================================================================
|
||||
// SMEM allocation — SAME LAYOUT as multihead kernel for T=1 compat
|
||||
// ================================================================
|
||||
// SMEM
|
||||
extern __shared__ char sbuf[];
|
||||
uint32_t* sTmemBase = (uint32_t*)sbuf;
|
||||
float* sRowMax = (float*)(sbuf + 4); // [32] per-warp rows
|
||||
float* sRowSum = sRowMax + 32; // [32]
|
||||
bf16_t* sQ0 = (bf16_t*)(((uintptr_t)(sRowSum + 32) + 15) & ~(uintptr_t)15);
|
||||
float* sRowMax = (float*)(sbuf + 4); // [MAX_ROWS]
|
||||
float* sRowSum = sRowMax + MAX_ROWS; // [MAX_ROWS]
|
||||
bf16_t* sQ0 = (bf16_t*)(((uintptr_t)(sRowSum + MAX_ROWS) + 15) & ~(uintptr_t)15);
|
||||
bf16_t* sK0 = sQ0 + TILE_SZ;
|
||||
bf16_t* sPk = (bf16_t*)(((uintptr_t)(sK0 + TILE_SZ) + 127) & ~(uintptr_t)127);
|
||||
bf16_t* sV = (bf16_t*)(((uintptr_t)(sPk + TILE_SZ) + 127) & ~(uintptr_t)127);
|
||||
float* s_p_vals = (float*)(sV + V_SUB_SZ); // [MAX_ROWS][SK_TILE]
|
||||
float* s_p_vals = (float*)(sV + V_SUB_SZ); // [MAX_ROWS][SK_TILE]
|
||||
|
||||
// ================================================================
|
||||
// TMEM allocation
|
||||
// ================================================================
|
||||
if (is_mma_warp) {
|
||||
uint32_t smem_ptr = __cvta_generic_to_shared(sTmemBase);
|
||||
tmem_alloc(smem_ptr, TMEM_N);
|
||||
}
|
||||
// TMEM alloc
|
||||
if (is_mma_warp) { uint32_t p = __cvta_generic_to_shared(sTmemBase); tmem_alloc(p, TMEM_N); }
|
||||
__syncthreads();
|
||||
uint32_t tb = *sTmemBase;
|
||||
|
||||
// ================================================================
|
||||
// QK GEMM loop
|
||||
// QK GEMM
|
||||
// ================================================================
|
||||
for (int kt = 0; kt < NKT_QK; kt++) {
|
||||
if (is_load_warp) {
|
||||
constexpr int CORES_MN = 128 / 8;
|
||||
for (int i = lane; i < TILE_SZ; i += 32) sQ0[i] = 0;
|
||||
// Q loading: write to same canonical positions for each K-tile
|
||||
// The UMMA descriptor always reads from sQ0 start
|
||||
for (int r = 0; r < T; r++) {
|
||||
for (int d = lane; d < MMA_K_BF16; d += 32) {
|
||||
int full_d = kt * MMA_K_BF16 + d; // GMEM index
|
||||
int ck = d / 8, lc = d % 8; // canonical position (same for all kt)
|
||||
int full_d = kt * MMA_K_BF16 + d;
|
||||
int ck = d / 8, lc = d % 8;
|
||||
int core_mn = r / 8, local_r = r % 8;
|
||||
sQ0[ck * CORES_MN * 64 + core_mn * 64 + local_r * 8 + lc] =
|
||||
q_head[r * HD + full_d];
|
||||
sQ0[ck * CORES_MN * 64 + core_mn * 64 + local_r * 8 + lc] = q_head[r * HD + full_d];
|
||||
}
|
||||
}
|
||||
for (int i = lane; i < TILE_SZ; i += 32) sK0[i] = 0;
|
||||
for (int r = 0; r < s_k; r++) {
|
||||
for (int d = lane; d < MMA_K_BF16; d += 32) {
|
||||
int full_d = kt * MMA_K_BF16 + d; // GMEM index
|
||||
int ck = d / 8, lc = d % 8; // canonical position (same for all kt)
|
||||
int full_d = kt * MMA_K_BF16 + d;
|
||||
int ck = d / 8, lc = d % 8;
|
||||
int tmn = r / 8, lr = r % 8;
|
||||
sK0[ck * CORES_MN * 64 + tmn * 64 + lr * 8 + lc] = k_head[r * HD + full_d];
|
||||
}
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
if (is_mma_warp) {
|
||||
uint32_t idesc = make_idesc(128, 128);
|
||||
uint64_t dq = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sQ0), 128);
|
||||
@@ -146,56 +114,90 @@ fmha_6warp_multirow_kernel(FmhaMultiRowParams params) {
|
||||
}
|
||||
|
||||
// ================================================================
|
||||
// Softmax (32x32b.x8, per-lane per-row)
|
||||
// SOFTMAX
|
||||
// ================================================================
|
||||
// For T ≤ 32: lane l handles row l. Lanes l >= T have no data.
|
||||
// No wmax/wsum — each lane computes its own row independently.
|
||||
// ================================================================
|
||||
if (is_softmax_warp) {
|
||||
float s_vals[SK_TILE], row_max = -INFINITY;
|
||||
|
||||
// Read S from TMEM
|
||||
for (int n = 0; n < SK_TILE / 8; n++) {
|
||||
float tmp[8];
|
||||
asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];"
|
||||
: "=f"(tmp[0]),"=f"(tmp[1]),"=f"(tmp[2]),"=f"(tmp[3]),
|
||||
"=f"(tmp[4]),"=f"(tmp[5]),"=f"(tmp[6]),"=f"(tmp[7])
|
||||
: "r"(tb + n * 8));
|
||||
asm volatile("tcgen05.wait::ld.sync.aligned;");
|
||||
if (lane < T) {
|
||||
for (int c = 0; c < 8; c++) {
|
||||
s_vals[n * 8 + c] = tmp[c] * scale;
|
||||
row_max = fmaxf(row_max, s_vals[n * 8 + c]);
|
||||
if (!full_tile) {
|
||||
// T <= 32: warp 0, 32x32b.x8, lane l = row l
|
||||
if (wid == 0) {
|
||||
float s_vals[SK_TILE], row_max = -INFINITY;
|
||||
for (int n = 0; n < SK_TILE / 8; n++) {
|
||||
float tmp[8];
|
||||
asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];"
|
||||
: "=f"(tmp[0]),"=f"(tmp[1]),"=f"(tmp[2]),"=f"(tmp[3]),
|
||||
"=f"(tmp[4]),"=f"(tmp[5]),"=f"(tmp[6]),"=f"(tmp[7])
|
||||
: "r"(tb + n * 8));
|
||||
asm volatile("tcgen05.wait::ld.sync.aligned;");
|
||||
if (lane < T) {
|
||||
for (int c = 0; c < 8; c++) {
|
||||
s_vals[n * 8 + c] = tmp[c] * scale;
|
||||
row_max = fmaxf(row_max, s_vals[n * 8 + c]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Store per-row max
|
||||
if (lane < T) sRowMax[lane] = row_max;
|
||||
|
||||
float row_sum = 0.0f;
|
||||
if (lane < T) {
|
||||
for (int j = 0; j < SK_TILE; j++) {
|
||||
s_vals[j] = expf(s_vals[j] - row_max);
|
||||
row_sum += s_vals[j];
|
||||
if (lane < T) sRowMax[lane] = row_max;
|
||||
float row_sum = 0.0f;
|
||||
if (lane < T) {
|
||||
for (int j = 0; j < SK_TILE; j++) { s_vals[j] = expf(s_vals[j] - row_max); row_sum += s_vals[j]; }
|
||||
}
|
||||
if (lane < T) sRowSum[lane] = row_sum;
|
||||
if (lane < T) {
|
||||
float inv = 1.0f / row_sum;
|
||||
for (int j = 0; j < SK_TILE; j++) s_p_vals[lane * SK_TILE + j] = s_vals[j] * inv;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// T > 32: 4 warps, 16x256b.x1, two-pass softmax
|
||||
if (is_softmax_warp) {
|
||||
const int warp_row_start = wid * 32;
|
||||
const int warp_row_end = min((wid + 1) * 32, T);
|
||||
const int lane_row_base = lane * 4;
|
||||
const bool active = (lane_row_base >= warp_row_start) && (lane_row_base < warp_row_end);
|
||||
const int vrows = active ? min(4, warp_row_end - lane_row_base) : 0;
|
||||
|
||||
// Store per-row sum
|
||||
if (lane < T) sRowSum[lane] = row_sum;
|
||||
float rmax[4], rsum[4];
|
||||
for (int r = 0; r < 4; r++) rmax[r] = -INFINITY;
|
||||
|
||||
// Normalize and write P to s_p_vals
|
||||
if (lane < T) {
|
||||
float inv_sum = 1.0f / row_sum;
|
||||
for (int j = 0; j < SK_TILE; j++) {
|
||||
s_p_vals[lane * SK_TILE + j] = s_vals[j] * inv_sum;
|
||||
// Pass 1: row_max
|
||||
for (int col = 0; col < SK_TILE; col++) {
|
||||
float v[4];
|
||||
asm volatile("tcgen05.ld.sync.aligned.16x256b.x1.b32 {%0,%1,%2,%3},[%4,%5];"
|
||||
: "=f"(v[0]), "=f"(v[1]), "=f"(v[2]), "=f"(v[3]) : "r"(tb), "r"(col));
|
||||
asm volatile("tcgen05.wait::ld.sync.aligned;");
|
||||
if (active) for (int r = 0; r < vrows; r++) rmax[r] = fmaxf(rmax[r], v[r] * scale);
|
||||
}
|
||||
|
||||
// Pass 2: exp, sum, write P
|
||||
for (int r = 0; r < 4; r++) rsum[r] = 0.0f;
|
||||
for (int col = 0; col < SK_TILE; col++) {
|
||||
float v[4];
|
||||
asm volatile("tcgen05.ld.sync.aligned.16x256b.x1.b32 {%0,%1,%2,%3},[%4,%5];"
|
||||
: "=f"(v[0]), "=f"(v[1]), "=f"(v[2]), "=f"(v[3]) : "r"(tb), "r"(col));
|
||||
asm volatile("tcgen05.wait::ld.sync.aligned;");
|
||||
if (active) {
|
||||
for (int r = 0; r < vrows; r++) {
|
||||
int gr = lane_row_base + r;
|
||||
float p = expf(v[r] * scale - rmax[r]);
|
||||
rsum[r] += p;
|
||||
s_p_vals[gr * SK_TILE + col] = p;
|
||||
}
|
||||
}
|
||||
}
|
||||
// Normalize and store max/sum
|
||||
if (active) {
|
||||
for (int r = 0; r < vrows; r++) {
|
||||
int gr = lane_row_base + r;
|
||||
sRowMax[gr] = rmax[r];
|
||||
sRowSum[gr] = rsum[r];
|
||||
float inv = 1.0f / rsum[r];
|
||||
for (int j = 0; j < SK_TILE; j++) s_p_vals[gr * SK_TILE + j] *= inv;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// ================================================================
|
||||
// PV GEMM loop
|
||||
// PV GEMM
|
||||
// ================================================================
|
||||
for (int n = 0; n < N_NSUB; n++) {
|
||||
int d_base = n * 16;
|
||||
@@ -205,32 +207,27 @@ fmha_6warp_multirow_kernel(FmhaMultiRowParams params) {
|
||||
for (int i = lane; i < TILE_SZ; i += 32) sPk[i] = 0;
|
||||
for (int r = 0; r < T; r++) {
|
||||
for (int c = lane; c < MMA_K_BF16; c += 32) {
|
||||
int global_col = kt * MMA_K_BF16 + c;
|
||||
float pval = s_p_vals[r * SK_TILE + global_col];
|
||||
int core_mn = r / 8, local_r = r % 8;
|
||||
int core_k = c / 8, local_c = c % 8;
|
||||
sPk[core_k * CORES_MN * 64 + core_mn * 64 + local_r * 8 + local_c] =
|
||||
f32_to_bf16(pval);
|
||||
int gc = kt * MMA_K_BF16 + c;
|
||||
float pv = s_p_vals[r * SK_TILE + gc];
|
||||
int cmn = r / 8, lr = r % 8, ck = c / 8, lc = c % 8;
|
||||
sPk[ck * CORES_MN * 64 + cmn * 64 + lr * 8 + lc] = f32_to_bf16(pv);
|
||||
}
|
||||
}
|
||||
for (int i = lane; i < V_SUB_SZ; i += 32) sV[i] = 0;
|
||||
for (int dd = lane; dd < 16; dd += 32) {
|
||||
for (int lr = 0; lr < MMA_K_BF16; lr++) {
|
||||
int r = kt * MMA_K_BF16 + lr;
|
||||
int g_mn = dd / 8, g_k = lr / 8;
|
||||
int llr = dd % 8, lc = lr % 8;
|
||||
sV[g_k * 2 * 64 + g_mn * 64 + llr * 8 + lc] =
|
||||
v_head[(d_base + dd) * s_k + r];
|
||||
int gmn = dd / 8, gk = lr / 8, llr = dd % 8, lc = lr % 8;
|
||||
sV[gk * 2 * 64 + gmn * 64 + llr * 8 + lc] = v_head[(d_base + dd) * s_k + r];
|
||||
}
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
if (is_mma_warp) {
|
||||
uint32_t idesc_pv16 = make_idesc(128, 16);
|
||||
uint32_t idesc16 = make_idesc(128, 16);
|
||||
uint64_t dp = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sPk), 128);
|
||||
uint64_t dv = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sV), 16);
|
||||
if (tid == 128) umma_ss_f16(tb + n * 16, dp, dv, idesc_pv16, kt > 0);
|
||||
if (tid == 128) umma_ss_f16(tb + n * 16, dp, dv, idesc16, kt > 0);
|
||||
asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
|
||||
}
|
||||
__syncthreads();
|
||||
@@ -238,41 +235,67 @@ fmha_6warp_multirow_kernel(FmhaMultiRowParams params) {
|
||||
}
|
||||
|
||||
// ================================================================
|
||||
// Epilogue (32x32b.x8, same as multihead kernel)
|
||||
// EPILOGUE
|
||||
// ================================================================
|
||||
if (is_softmax_warp) {
|
||||
float row_max = (lane < T) ? sRowMax[lane] : 0.0f;
|
||||
float row_sum = (lane < T) ? sRowSum[lane] : 1.0f;
|
||||
float o_vals[HD];
|
||||
|
||||
for (int n = 0; n < HD / 8; n++) {
|
||||
float tmp[8];
|
||||
asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];"
|
||||
: "=f"(tmp[0]),"=f"(tmp[1]),"=f"(tmp[2]),"=f"(tmp[3]),
|
||||
"=f"(tmp[4]),"=f"(tmp[5]),"=f"(tmp[6]),"=f"(tmp[7])
|
||||
: "r"(tb + n * 8));
|
||||
asm volatile("tcgen05.wait::ld.sync.aligned;");
|
||||
if (!full_tile) {
|
||||
// T <= 32: warp 0, 32x32b.x8
|
||||
if (wid == 0) {
|
||||
float rm = (lane < T) ? sRowMax[lane] : 0.0f;
|
||||
float rs = (lane < T) ? sRowSum[lane] : 1.0f;
|
||||
float o_vals[HD];
|
||||
for (int n = 0; n < HD / 8; n++) {
|
||||
float tmp[8];
|
||||
asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];"
|
||||
: "=f"(tmp[0]),"=f"(tmp[1]),"=f"(tmp[2]),"=f"(tmp[3]),
|
||||
"=f"(tmp[4]),"=f"(tmp[5]),"=f"(tmp[6]),"=f"(tmp[7])
|
||||
: "r"(tb + n * 8));
|
||||
asm volatile("tcgen05.wait::ld.sync.aligned;");
|
||||
if (lane < T) for (int c = 0; c < 8; c++) o_vals[n * 8 + c] = tmp[c];
|
||||
}
|
||||
if (lane < T) {
|
||||
for (int c = 0; c < 8; c++) o_vals[n * 8 + c] = tmp[c];
|
||||
float inv = 1.0f / rs;
|
||||
for (int d = 0; d < HD; d++) o_head[lane * HD + d] = f32_to_bf16(o_vals[d] * inv);
|
||||
if (lse_head) lse_head[lane] = logf(rs) + rm;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// T > 32: 4 warps, 16x256b.x1
|
||||
if (is_softmax_warp) {
|
||||
const int warp_row_start = wid * 32;
|
||||
const int warp_row_end = min((wid + 1) * 32, T);
|
||||
const int lane_row_base = lane * 4;
|
||||
const bool active = (lane_row_base >= warp_row_start) && (lane_row_base < warp_row_end);
|
||||
const int vrows = active ? min(4, warp_row_end - lane_row_base) : 0;
|
||||
|
||||
if (lane < T) {
|
||||
float inv_row_sum = 1.0f / row_sum;
|
||||
// Read O from TMEM one column at a time
|
||||
for (int d = 0; d < HD; d++) {
|
||||
o_head[lane * HD + d] = f32_to_bf16(o_vals[d] * inv_row_sum);
|
||||
float v[4];
|
||||
asm volatile("tcgen05.ld.sync.aligned.16x256b.x1.b32 {%0,%1,%2,%3},[%4,%5];"
|
||||
: "=f"(v[0]), "=f"(v[1]), "=f"(v[2]), "=f"(v[3]) : "r"(tb), "r"(d));
|
||||
asm volatile("tcgen05.wait::ld.sync.aligned;");
|
||||
if (active) {
|
||||
for (int r = 0; r < vrows; r++) {
|
||||
int gr = lane_row_base + r;
|
||||
if (gr < T) {
|
||||
float inv = 1.0f / sRowSum[gr];
|
||||
o_head[gr * HD + d] = f32_to_bf16(v[r] * inv);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if (lse_head) {
|
||||
lse_head[lane] = logf(row_sum) + row_max;
|
||||
// LSE
|
||||
if (active) {
|
||||
for (int r = 0; r < vrows; r++) {
|
||||
int gr = lane_row_base + r;
|
||||
if (gr < T && lse_head) lse_head[gr] = logf(sRowSum[gr]) + sRowMax[gr];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// TMEM dealloc
|
||||
if (is_mma_warp) {
|
||||
tmem_dealloc(tb, TMEM_N);
|
||||
}
|
||||
if (is_mma_warp) tmem_dealloc(tb, TMEM_N);
|
||||
}
|
||||
|
||||
} // namespace dsv4::kernels::attention
|
||||
|
||||
Reference in New Issue
Block a user