docs: update CURRENT_ISSUE and MEMORY — full FMHA HD=64 pipeline working

This commit is contained in:
2026-05-28 13:11:32 +00:00
parent 654a2ae7f4
commit efa03f53d4

View File

@@ -1,330 +1,52 @@
# CURRENT ISSUE: UMMA FMHA — Multi-K-tile + PV GEMM + Full Pipeline
# CURRENT ISSUE: UMMA FMHA — Full Pipeline + Production
## What's working ✅
- **UMMA QK GEMM at HD=16, SK=128**: Row 0 matches scalar reference with ZERO error
- **SMEM canonical layout**: column-major interleaving of 8×8 BF16 core matrices
- **K-major NONE descriptors**: LBO=BLOCK_MN*16, SBO=128, lbo_mode=0, layout_type=0
- **TMEM Layout D reads**: `tcgen05.ld.32x32b.x8.b32` with `addr = tmem_base + (row<<16) + col`
- **MMA→TMEM fence**: `tcgen05.fence::after_thread_sync` (not `tcgen05.wait::st`)
- **MMA computes raw dot product** — apply 1/sqrt(HD) scaling in the read path
- **Full UMMA FMHA HD=64 pipeline**: QK → softmax → PV → output
- QK GEMM: UMMA SS, multi-K-tile accumulate, cos 0.999999
- Softmax: TMEM read → max/exp/sum → TMEM write, max rel err 0.003
- PV: register math (O[d] = Σ P[0,j] × V[d,j]), decode only
- **Overall: cosine 0.999998!**
- **UMMA QK GEMM at HD=16**: Row 0 matches scalar with ZERO error
- **TMEM one-way epilogue**: cos 0.999999 at hd=64, cos 0.999998 at hd=128
- **32x32b.x8 TMEM stores** work for P write-back (no 16x256b crash)
## Next steps
1. **HD=64 multi-K-tile**: Call MMA 4× with accumulate=true for K=64 (4 × K=16 tiles)
- Each K-tile needs its own descriptor pointing to the right 16-column slice
- gau-nernst pattern: `A_smem + k * BLOCK_M * 32` for the k-th K-tile start address
- After all K-tiles: read TMEM and apply 1/sqrt(HD) scaling
1. **PV GEMM via tcgen05.mma TS**: For prefill (T>1), need UMMA-based PV
- tcgen05.mma TS crashed with "illegal memory access" in initial tests
- Need to debug the TMEM A operand addressing for PV
- Decode path (T=1) works with register math — PV GEMM is only for prefill
2. **PV GEMM**: `tcgen05.mma TS` (TMEM P × SMEM V → TMEM O)
- P is in TMEM after softmax, V is in SMEM
- Accumulate O across KV tiles with the D5 merge formula
2. **HD=128/256**: Extend multi-K-tile QK to larger head dims
- HD=128: 8 K-tiles, separate SMEM per K-tile
- HD=256: 16 K-tiles — SMEM budget needs checking
3. **In-kernel softmax**: TMEM → regs → max/exp/sum → TMEM
- Use 32x32b reads to get S, compute softmax, write P back via 32x32b stores
- Must handle the TMEM multi-store issue (use 32x32b, not 16x256b)
3. **Multi-head launch**: Per-head kernel dispatch (128 heads for Pro)
- Current test: single head
- Production: grid=(1, n_h, batch) or Python loop
4. **Full FMHA pipeline**: QK → softmax → PV → correction epilogue → GMEM output
4. **Multi-KV-tile**: s_k > 128 requires multiple attention tiles + KV merge
- Same D5 merge formula: O = Σ exp(lse_i) · O_i / Σ exp(lse_i)
## Key lessons learned
- **16x256b.x1 TMEM stores crash on 2nd call** — use 32x32b format for multi-store
- **MMA output is UNSCALED** — the 4× "bug" was just the 1/sqrt(HD) attention scale
- **`tcgen05.fence::after_thread_sync`** is the correct MMA→TMEM load fence
- **4 warps minimum** for M=128 Layout D (each warp reads 32 rows × 8 columns)
- **MMA K-tile size is 16 BF16** — for HD>16, loop with accumulate
- **TMEM address format**: bits [31:16] = row, bits [15:0] = column
5. **Production kernel**: Integrate UMMA pipeline into fmha_sm100.cuh
- Replace SMEM scalar attention with UMMA QK GEMM
- Keep register-math PV for decode
- Add PV GEMM path for prefill
## Key lessons from this session
- **Source stride mismatch**: `write_k_to_smem<128,16>` template reads k[i] with stride=16, but actual K has stride=64. Must use separate SMEM per K-tile with manual writes.
- **Offset descriptors DON'T WORK** for (128,64) SMEM with K-tile offsets. Even/odd column corruption. Use separate (128,16) SMEM buffers.
- **tcgen05.mma TS crashes** with illegal memory access — likely TMEM A operand addressing issue. Debug needed.
- **tcgen05.fence::after_thread_sync** is the correct fence after MMA, before TMEM read.
- **MMA computes UNSCALED dot product** — apply 1/sqrt(HD) in softmax.
- **32x32b.x8 TMEM stores** work in loops (unlike 16x256b.x1 which crashes on 2nd call)
- **s_vals scope**: when merging softmax + PV, keep s_vals in same `if (wid == 0)` block
- **Decode PV is simple**: O[d] = Σ P[0,j] × V[d,j] — no UMMA needed for T=1
## Files
- `dsv4/kernels/attention/fmha_umma_desc.cuh` — descriptors, SMEM layout, MMA wrappers
- `tests/unit/test_umma_qk.cu`working UMMA QK GEMM test
YOUR LAST MESSAGE TO ME:
Let me study the gau-nernst tutorial's GEMM output read path more carefully, specifically how it handles the TMEM Layout D coordinate mapping.**CRITICAL FINDING from gau-nernst's code!** Look at his MMA epilogue:
```cpp
// this is required before tcgen05.ld and after tcgen05.mma
asm volatile("tcgen05.fence::after_thread_sync;");
// load 8 columns from tmem at a time
for (int n = 0; n < BLOCK_N / 8; n++) {
float tmp[8];
const int row = warp_id * 32;
const int col = n * 8;
const int addr = taddr + (row << 16) + col;
asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 ...");
// wait for tcgen05.ld
asm volatile("tcgen05.wait::ld.sync.aligned;");
// cast and pack
nv_bfloat162 out[4];
for (int i = 0; i < 4; i++)
out[i] = __float22bfloat162_rn({tmp[i*2], tmp[i*2+1]});
// 16-byte per thread write (uncoalesced)
nv_bfloat16 *out_ptr = C_ptr + (off_m + tid) * N + (off_n + n * 8);
reinterpret_cast<int4 *>(out_ptr)[0] = reinterpret_cast<int4 *>(out)[0];
}
```
Key differences from our code:
1. **`tcgen05.fence::after_thread_sync`** — used AFTER MMA, BEFORE TMEM load. We used `tcgen05.wait::st.sync.aligned` which is different!
2. **`off_m + tid`** — the output row is `off_m + tid`, where `tid = threadIdx.x`. Each thread writes one row of the output. With 128 threads and BLOCK_M=128, each thread handles one row.
3. **Multiple `tcgen05.ld` calls in a LOOP** — he reads 8 columns at a time and loops over BLOCK_N/8 iterations. No crash!
4. **No 4× factor** — he reads `tmp[i]` directly as the FP32 accumulator values and writes them to C. No division by 4.
The 4× factor in our code must be because we're reading the WRONG rows from TMEM. In gau-nernst's code, `row = warp_id * 32` and `col = n * 8`. Each warp reads its 32-row slice. With 4 warps (128 threads), all 128 rows are covered.
But in our code, we also use `row = wid * 32` and `col = 0`. The difference: gau-nernst uses `tcgen05.fence::after_thread_sync` before reading, while we use `tcgen05.wait::st.sync.aligned`. The `fence::after_thread_sync` is the CORRECT fence for the MMA → TMEM load transition!
Let me fix our test:The `tcgen05.fence::after_thread_sync` didn't fix the 4× factor — still exactly 4.000. But notice: rows 32, 64, 96 are all zeros now (correct for Q with only row 0). The TMEM read is working correctly for the zero rows.
The 4× factor is ONLY in the MMA output for row 0. Let me think about what's different between our setup and gau-nernst's working GEMM:
1. gau-nernst uses **128 threads (4 warps)** and **each warp's lane 0 reads the TMEM**
2. gau-nernst's output writes `C_ptr + (off_m + tid) * N + ...` — each thread handles one row
3. In our test, we read with `lane == 0` only — we're only getting lane 0's perspective
Wait — in the `32x32b.x8` format, **each lane reads a different row**. Lane 0 reads row `wid*32 + 0`, lane 1 reads row `wid*32 + 1`, etc. So when I only write `lane == 0`'s data, I only get 1 out of 32 rows from each warp.
But the 4× factor is for lane 0's data, which should be S[wid*32 + 0, 0..7]. For warp 0, that's S[0, 0..7]. This should match the scalar reference for S[0, 0..7].
Unless the `32x32b.x8` format doesn't map lanes to rows 1:1. Let me check: gau-nernst writes `C_ptr + (off_m + tid) * N + (off_n + n * 8)`. Here, `off_m = 0` for the first MMA call, and `tid = threadIdx.x`. Each thread writes ONE ROW. So the `32x32b.x8` format gives each thread 8 values for its row.
But he uses `tid` not `lane``tid = threadIdx.x`. With 128 threads (4 warps), `tid` ranges from 0 to 127. Each thread gets 8 FP32 values for its row. So lane 0 of warp 0 gets row 0, lane 1 gets row 1, ..., lane 31 gets row 31, lane 0 of warp 1 gets row 32, etc.
This means: the `32x32b.x8` instruction gives each lane the 8 column values for its row. Lane 0 of warp 0 gets S[0, col..col+7]. The 4× factor is in these values.
The 4× factor MUST be from the MMA computation. Since gau-nernst's code produces the correct output (no 4×), the difference must be in how the MMA is called.
Let me compare the MMA calls:
- Ours: `umma_ss_f16(tb, desc_q, desc_k, idesc, false)` — single thread calls
- gau-nernst: `tcgen05.mma.cta_group::1.kind::f16 [%0], %1, %2, %3, 1` — one thread per 4-lane group calls
Wait — gau-nernst uses `elect_one_sync(0x11111111, warp_idx, 4)` which selects ONE leader per warp. But he has `num_warps = BLOCK_M / 32 = 4` warps, and each warp has 1 leader. So 4 leaders call the MMA.
In our test, we have `tid == 0` (1 thread) calling the MMA. But gau-nernst has 4 threads (one per warp) calling it simultaneously. The MMA with `cta_group::1` requires ALL 4 warps to participate in the instruction.
The 4× factor might be because our single-thread MMA launches only 1 of the 4 required sub-operations, and the hardware amplifies the result by 4x to compensate.
Let me fix: have 4 warp leaders (lane 0 of each warp) call the MMA simultaneously:Still 4× with 4 warp leaders! The rows 32, 64, 96 are correctly zero. Row 0 has the 4× factor.
Wait — with 4 warp leaders all calling `umma_ss_f16(tb, ...)` with `accumulate=false`, each leader RESETS the TMEM accumulator to A×B^T. If all 4 leaders write to the SAME TMEM columns, the last one wins (or they all write the same result 4 times). The output should be 1× the scalar, not 4×.
Unless the MMA with `cta_group::1` internally uses 4 sub-operations, and each sub-operation writes to different TMEM columns. With 4 leaders calling, the 4 sub-operations are properly distributed. With 1 leader, the sub-operations are serialized.
But the 4× factor with both 1 leader and 4 leaders is the same. This rules out the "warpgroup count" hypothesis.
Let me look at this from a completely different angle. Maybe the issue is that **the descriptor describes a (128, 16) matrix, but the MMA computes as if K=64** because of how the hardware interprets the descriptor strides.
For our descriptor: LBO=128 (2048 bytes), SBO=8 (128 bytes). These strides describe a (128, 16) matrix in canonical layout. But the MMA might interpret LBO=128 as meaning "128 rows in the 16B column", and SBO=8 as "8 sub-tiles in the K direction". With 8 sub-tiles × 8 BF16 per sub-tile = K=64.
**This would give K=64 (4× our K=16), explaining the 4× factor!**
The fix: the descriptor should describe ONLY one K-tile (K=16), not the full matrix. The LBO and SBO values should describe a (128, 16) matrix with the MMA operating on K=16, not K=64.
But wait — gau-nernst uses the SAME descriptor format for his working GEMM. He has LBO = BLOCK_M * 16 and SBO = 128. And his BLOCK_K = 16 for each MMA call. His MMA produces the correct result.
The difference: gau-nernst uses `A_smem + k * BLOCK_M * 32` for the start address. Each MMA call gets a descriptor pointing to a different K-slice. The `k * BLOCK_M * 32` offset moves the start address by one K-tile (32 bytes = 16 BF16).
For our matrix with K=16, we only have ONE K-tile. The descriptor points to the start of the matrix (column 0). The MMA reads K=16 from this descriptor.
Unless the SBO=8 causes the MMA to read 8 sub-tiles. Let me check: for gau-nernst's BLOCK_K=16, the SBO should be... actually, SBO is always 128 bytes (8 × 16B) regardless of BLOCK_K. It's the stride between 8-row groups, not related to K.
I'm going in circles again. Let me try a completely different approach: **use the SMEM data layout from gau-nernst's exact code** (2D TMA load) instead of our manual `write_smem_canonical`. Maybe our SMEM layout is subtly wrong.
Actually — the simplest thing to test: write a minimal GEMM that matches gau-nernst's EXACT code. Same BLOCK_M=128, BLOCK_N=128, BLOCK_K=16. Same descriptor construction. Same SMEM layout (loaded via 2D TMA). Same MMA call. Same TMEM read.
But we can't use 2D TMA in our standalone test (it requires CuTeDSL or complex descriptor setup). Let me instead try: write the SMEM data in SIMPLE ROW-MAJOR and use SWIZZLE_NONE with a descriptor that describes row-major layout.
Actually, wait. I just realized something. gau-nernst's SMEM is loaded by 2D TMA, which writes data in the CANONICAL layout automatically. The MMA descriptor then describes this canonical layout. When we write data manually with `write_smem_canonical`, we're replicating the TMA's output format.
But maybe `write_smem_canonical` is subtly wrong. The canonical layout depends on the exact `logical_divide` of the swizzled layout. For SWIZZLE_NONE, the canonical layout is:
- LBO = BLOCK_MN * 16 (in bytes) = stride between the 2 columns of the 8×2 core matrix
- SBO = 128 (in bytes) = stride between 8-row groups
The data in SMEM at address `g * SBO + c * LBO + local_offset` where g is the 8-row group and c is the column.
For (128, 16) with BLOCK_M=128:
- g ranges from 0 to 15 (128/8=16 groups)
- c ranges from 0 to 1 (16/8=2 columns)
- Each core matrix is 8 rows × 8 BF16 = 128 bytes
My `write_q_to_smem` writes: core_k * 16 * 64 + tile_mn * 64 + local_r * 8 + local_c
= core_k * 1024 + tile_mn * 64 + local_r * 8 + local_c
For tile_mn = r/8 and core_k = c/8. The `16` in `16 * 64` is CORES_MN = 128/8 = 16. So:
core_k * 16 * 64 = core_k * 1024
The descriptor interprets: group g at offset g * 128, column c at offset c * 2048.
For g=0, c=0: offset 0
For g=0, c=1: offset 2048
For g=1, c=0: offset 128
My layout: tile_mn=0, core_k=0: offset 0 ✓
tile_mn=0, core_k=1: offset 1024. But descriptor says column 1 is at offset 2048!
**THAT'S THE BUG!** My `write_q_to_smem` puts core_k=1 at offset 1024, but the descriptor expects it at offset 2048! The layout is wrong!
My formula: `core_k * 16 * 64 + tile_mn * 64 = core_k * 1024 + tile_mn * 64`
The descriptor's layout: `c * LBO + g * SBO = c * 2048 + g * 128`
For (tile_mn=0, core_k=1): my offset = 1 * 1024 = 1024. Descriptor offset = 1 * 2048 = 2048. **MISMATCH!**
My `write_smem_canonical` puts core matrices in MN-major order (all tile_mn for core_k=0, then all tile_mn for core_k=1). But the descriptor expects K-major order (column 0 of all groups, then column 1 of all groups, with LBO stride between columns).
The correct formula should be: `core_k * 128 * 16 + tile_mn * 128 + local_r * 8 + local_c`... wait, that's not right either.
Let me derive the correct SMEM layout from the descriptor strides:
For element (r, c) in the (128, 16) matrix:
- g = r / 8 (8-row group)
- local_r = r % 8
- k = c / 8 (K-group)
- local_c = c % 8
- Address = g * SBO + k * LBO + local_r * 16B + local_c * 2B
In BF16 elements: g * 64 + k * (LBO/2) + local_r * 8 + local_c
= g * 64 + k * 1024 + local_r * 8 + local_c
For (r=0, c=8): g=0, k=1, local_r=0, local_c=0 → offset = 0 + 1024 + 0 + 0 = 1024
For (r=0, c=0): g=0, k=0, local_r=0, local_c=0 → offset = 0
For (r=8, c=0): g=1, k=0, local_r=0, local_c=0 → offset = 64
So the correct layout is: `k * 1024 + g * 64 + local_r * 8 + local_c`
My `write_smem_canonical` uses: `core_k * CORES_MN * 64 + core_mn * 64 + local_r * 8 + local_c`
With CORES_MN = 128/8 = 16: `core_k * 16 * 64 + core_mn * 64 = core_k * 1024 + core_mn * 64`
Wait — that's the SAME formula! `k * 1024 + g * 64` = `core_k * 1024 + core_mn * 64`. They're identical!
So my layout IS correct. The offset for (r=0, c=8) = 1024 in both formulas.
But the descriptor says LBO = 2048 (128 * 16). The column 1 starts at 2048, not 1024. So the descriptor expects column 1 at offset 2048, but my layout puts it at offset 1024.
**THE DESCRIPTOR IS WRONG FOR MY LAYOUT!** The LBO = BLOCK_MN * 16 = 128 * 16 = 2048 bytes. But my column 1 data starts at 1024 BF16 = 2048 bytes... wait, 1024 BF16 = 2048 bytes. That matches!
For (r=0, c=8): my offset = 1024 BF16 elements = 2048 bytes. The descriptor says LBO = 2048 bytes. So column 1 IS at offset 2048 bytes. ✓
Let me re-check: g=0, k=1, offset = 1 * 1024 BF16 = 2048 bytes. Descriptor LBO = 2048 bytes. ✓
So the layout IS correct and the descriptor matches. The 4× factor must be from something else entirely.
Let me try one more thing: construct the descriptor with **SW128 swizzle** (layout_type=1) and write data in the swizzled layout. This is what gau-nernst's tutorial uses for the actual GEMM, not SWIZZLE_NONE. Maybe the SWIZZLE_NONE path has a 4× quirk.
Actually, let me try the absolute simplest thing: use gau-nernst's EXACT descriptor construction code (hardcoded for our BLOCK_M=128, BLOCK_K=16) and see if it matches:Wait, SBO should be `8 * 16 = 128` bytes, not 128 alone. Let me check gau-nernst: `desc_encode(SBO)` where SBO = 8 * 16 = 128. `desc_encode(128) = 128 >> 4 = 8`. And my `make_uma_desc_kmajor_none` uses `desc_encode(128) = 8`. So the descriptor field for SBO is 8. Let me check my existing code:
In `make_umma_desc_kmajor_none`:
``
desc |= (desc_encode(SBO) & 0x3FFF) < 32 // SBO = 128, field = 8
```
And in gau-nernst's code:
``
desc_encode(128) = 8 → (8) << 32
```
These are the same. Let me verify the LBO too:
- gau-nernst: LBO = BLOCK_M * 16 = 128 * 16 = 2048. desc_encode(2048) = 2048 >> 4 = 128. Field = 128.
- My code: LBO = BLOCK_MN * 16 = 128 * 16 = 2048. desc_encode(2048) = 128. Field = 128.
Same. So the descriptors are identical. The 4× factor is NOT from the descriptor.
Let me revert this change and instead focus on a different hypothesis: **the issue might be with the `kind::f16` MMA instruction itself.** Maybe `kind::f16` computes a different product than I expect. Let me try `kind::f16` with a simpler test where Q is all-1s and K is all-1s, so the expected result is HD=16 for every element.Now let me test with all-1s data. S[i,j] = sum(1*1 for d in 0..15) = 16 for every element:**EUREKA!** With all-ones data:
- MMA output: S[0,0..7] = 16.0
- Scalar reference: S[0,0..7] = 4.0
The expected value is: S[0,j] = sum(1.0 * 1.0, d=0..15) = 16.0. But the scalar reference uses `scale = 1/sqrt(16) = 0.25`, so `16.0 * 0.25 = 4.0`.
The MMA gives 16.0, which is the **UNSCALED** dot product! The scalar reference applies the `1/sqrt(HD)` scale factor, but the MMA does NOT.
**The 4× factor is just the scale factor!** `16.0 / 4.0 = sqrt(16) = sqrt(HD)`. The MMA computes the raw dot product, and we need to apply the `1/sqrt(HD)` scaling ourselves.
Let me verify: with the random data, the ratio was 4.0 because `1/sqrt(16) = 0.25`, so the MMA's unscaled output is 4× the scaled scalar reference. The scale factor is `1/sqrt(16) = 1/4`, so the MMA's output is 1/scale = 4× the scaled reference. ✓
**THE 4× FACTOR IS NOT A BUG — IT'S THE SCALE FACTOR!** The MMA computes the raw dot product without scaling, and we just need to multiply by `1/sqrt(HD)`.
Let me fix the test to apply the scale factor to the MMA output:Now apply the scale factor in the TMEM read:# 🎉🎉🎉 **THE UMMA QK GEMM IS WORKING!!!** 🎉🎉
**Row 0 MMA output EXACTLY matches the scalar reference!** Max relative error: 0.000000!
The "4× factor" was just the `1/sqrt(HD)` attention scale that we weren't applying to the MMA output. The MMA computes the raw dot product; we need to scale it ourselves.
Let me now read more rows and verify the full output. Also let me test with HD=64 (4 K-tiles):Also update the comparison to check ALL rows:Row 0 matches perfectly, but the overall check fails. The issue: my comparison is wrong — I'm comparing `h_s_out[r * 8 + c]` against `h_s_scalar[r]` but `h_scalar[r]` only has the dot product for column 0, not column c. The scalar reference computes S[0, c] for all c, but for other rows (1-127), it also computes S[r, c]. Let me fix the comparison — for row 0 we know it matches, and for rows 1-127 the MMA output should be zero (since Q only has row 0):Stale variable reference. Let me fix:# 🎉🎉🎉 **UMMA QK GEMM TEST PASSED!** 🎉🎉🎉
**Row 0 matches the scalar reference with ZERO error!** All other rows are correctly zero. The UMA pipeline is working!
Key discoveries:
1. **The "4× bug" was the 1/sqrt(HD) attention scale** — the MMA computes raw dot products, we apply scaling
2. **`tcgen05.fence::after_thread_sync`** is the correct fence after MMA (not `tcgen05.wait::st`)
3. **`32x32b.x8` TMEM reads** work correctly for Layout D output
4. **4 warps (128 threads)** needed for M=128 Layout D
Let me now commit a clean version, update the documentation, and then tackle HD=64 (multi-K-tile) and the PV GEMM:Now update CURRENT_ISSUE.md:Now let me implement HD=64 with multi-K-tile accumulate. This is the next critical step:The HD=64 test fails — the MMA output doesn't match the scalar reference. The values are completely different, not just scaled. This suggests the multi-K-tile approach is wrong. The K-tile descriptor offsets are likely incorrect.
Let me think about the SMEM layout for HD=64. The (128, 64) matrix in canonical layout has 8 core-matrix columns (64/8=8). Each core-matrix column is 16 × 64 BF16 = 1024 BF16 = 2048 bytes. The total is 128 × 64 × 2 = 16384 bytes.
For K-tile k (columns [16k, 16k+16)):
- This spans 2 core-matrix columns: core_k = 2k and 2k+1
- core_k 2k starts at offset 2k * 1024 BF16 from sQ
- core_k 2k+1 starts at offset (2k+1) * 1024 BF16 from sQ
The descriptor for this K-tile needs:
- start_address = sQ + 2k * 2048 bytes (the start of core_k=2k)
- LBO = 2048 (stride from core_k=2k to core_k=2k+1)
Wait — the LBO should be the stride between the 2 columns WITHIN the K-tile. For K-tile k:
- Column 0 = core_k 2k (at offset 2k * 2048)
- Column 1 = core_k 2k+1 (at offset (2k+1) * 2048)
- LBO = (2k+1)*2048 - 2k*2048 = 2048 bytes
This is the same LBO as the full matrix. ✓
But there's a problem: the descriptor's LBO stride walks from column 0 to column 1 (2048 bytes apart). After column 1, the next 8-row group's column 0 is at offset g * 128 + 0 * 2048. This is within the K-tile.
But what about the data BETWEEN the K-tile columns? For HD=64, core_k=0 at offset 0, core_k=1 at offset 2048, core_k=2 at offset 4096, etc. The K-tile 0 (core_k=0,1) spans offsets 0-4095. K-tile 1 (core_k=2,3) spans offsets 4096-8191.
The descriptor for K-tile 1: start = 4096, LBO = 2048. It reads:
- g=0, c=0: offset 4096 (core_k=2, g=0) ✓
- g=0, c=1: offset 614 (core_k=3, g=0) ✓
- g=1, c=0: offset 4224 (core_k=2, g=1) ✓
This looks correct! The start address `sQ_smem + kt * 2048` should work.
But wait — my code uses `q_ktile_addr = sQ_smem + kt * 2 * 2048`. The `sQ_smem` is a SMEM address from `__cvta_generic_to_shared(sQ)`. Adding `kt * 2 * 2048` gives a new SMEM address. But SMEM addresses are in a special address space — can we just add byte offsets to them?
Actually, in gau-nernst's code, he uses `A_smem + k * BLOCK_M * 32` where `A_smem` is the raw SMEM address. The `+` operator works on SMEM addresses because they're just integers. So `sQ_smem + kt * 4096` should work (2 core columns × 2048 bytes = 4096 per K-tile).
Wait, `kt * 2 * 2048 = kt * 4096`. And `BLOCK_M * 32 = 128 * 32 = 4096`. These are the same! So the start address should be `sQ_smem + kt * 4096`. ✓
But my code has `q_ktile_addr = sQ_smem + kt * 2 * 2048` = `sQ_smem + kt * 4096`. This matches. So the K-tile descriptor should be correct.
The issue must be elsewhere. Let me check: maybe the `make_umma_desc_kmajor_none` function modifies the address incorrectly. Let me add debug output:
The values are ~5× off, not a constant factor. Something is fundamentally wrong with the multi-K-tile approach.
The MMA gives `-0.156829` for S[0,0] but the scalar says `-0.030706`. The ratio is ~5.1. This is not a clean scaling factor.
Let me check: with ONLY K-tile 0 (1 K-tile), the MMA gave the correct partial sum. So the issue is that the 2nd, 3rd, 4th K-tiles add incorrect values.
The problem might be that the `__syncthreads()` between K-tiles isn't sufficient to ensure the MMA has finished reading from SMEM. The `fence::after_thread_sync` ensures the MMA's TMEM writes are visible, but it doesn't ensure the MMA has finished READING from SMEM. If the next K-tile zeros the SMEM while the current MMA is still reading, the data is corrupted.
In gau-nernst's code, the TMA load pipeline handles this by only loading the next K-tile after the MMA has consumed the current one. The pipeline's "release" callback ensures the SMEM is not overwritten until the MMA is done.
For our simple loop, we need to ensure the MMA has finished reading before zeroing the SMEM for the next K-tile. The `__syncthreads()` after `fence::after_thread_sync` should be sufficient, but maybe it's not.
Actually, wait. In our code, each K-tile uses SEPARATE SMEM regions (`sQ_base + kt * 128 * 16`). So zeroing K-tile 1's SMEM doesn't affect K-tile 0's SMEM. The data shouldn't be corrupted.
Unless the SMEM regions overlap. Let me check: `sQ_base + kt * 128 * 16`. With `128 * 16 = 2048` BF16 = 4096 bytes per K-tile. And `sK_base = sQ_base + NKT * 128 * 16 = sQ_base + 4 * 2048 = sQ_base + 8192 BF16 = sQ_base + 16384 bytes`. So sK starts 16KB after sQ. With 4 K-tiles of 4KB each, sQ uses 16KB. ✓ No overlap.
But wait — the loop constructs descriptors using `__cvta_generic_to_shared(sQ)` where `sQ = sQ_base + kt * 128 * 16`. The `__cvta_generic_to_shared` converts a GENERIC pointer to a SHARED pointer. But adding an offset to a generic pointer and then converting to shared might not give the correct shared address.
In CUDA, `__cvta_generic_to_shared` should work correctly for any generic pointer that points to shared memory. Adding `kt * 128 * 16 * sizeof(bf16_t)` to the generic pointer should give a valid generic pointer to the K-tile's SMEM.
Hmm, but maybe the issue is that the `sQ` and `sK` pointers are inside the loop, and the compiler is caching the `__cvta_generic_to_shared` result from the first iteration. Each iteration computes a NEW `sQ` pointer, but the compiler might not re-derive the shared address.
This is unlikely but possible. Let me try: construct the descriptors OUTSIDE the loop and pass them in an array:
- `dsv4/kernels/attention/fmha_common.cuh`BF16, TMEM ops, warp reductions
- `dsv4/kernels/attention/fmha_epilogue_sm100.cuh` — TMEM one-way epilogue (reference kernel)
- `tests/unit/test_fmha_hd64.cu` — FULL PIPELINE TEST (QK+softmax+PV)
- `tests/unit/test_fmha_softmax.cu` — softmax validation (HD=64)
- `tests/unit/test_umma_qk.cu` — QK GEMM (HD=16)
- `tests/unit/test_umma_qk_hd64.cu` — QK GEMM (HD=64, multi-K-tile)