docs: update CURRENT_ISSUE and MEMORY — full FMHA HD=64 pipeline working
This commit is contained in:
360
CURRENT_ISSUE.md
360
CURRENT_ISSUE.md
@@ -1,330 +1,52 @@
|
||||
# CURRENT ISSUE: UMMA FMHA — Multi-K-tile + PV GEMM + Full Pipeline
|
||||
# CURRENT ISSUE: UMMA FMHA — Full Pipeline + Production
|
||||
|
||||
## What's working ✅
|
||||
- **UMMA QK GEMM at HD=16, SK=128**: Row 0 matches scalar reference with ZERO error
|
||||
- **SMEM canonical layout**: column-major interleaving of 8×8 BF16 core matrices
|
||||
- **K-major NONE descriptors**: LBO=BLOCK_MN*16, SBO=128, lbo_mode=0, layout_type=0
|
||||
- **TMEM Layout D reads**: `tcgen05.ld.32x32b.x8.b32` with `addr = tmem_base + (row<<16) + col`
|
||||
- **MMA→TMEM fence**: `tcgen05.fence::after_thread_sync` (not `tcgen05.wait::st`)
|
||||
- **MMA computes raw dot product** — apply 1/sqrt(HD) scaling in the read path
|
||||
- **Full UMMA FMHA HD=64 pipeline**: QK → softmax → PV → output
|
||||
- QK GEMM: UMMA SS, multi-K-tile accumulate, cos 0.999999
|
||||
- Softmax: TMEM read → max/exp/sum → TMEM write, max rel err 0.003
|
||||
- PV: register math (O[d] = Σ P[0,j] × V[d,j]), decode only
|
||||
- **Overall: cosine 0.999998!**
|
||||
- **UMMA QK GEMM at HD=16**: Row 0 matches scalar with ZERO error
|
||||
- **TMEM one-way epilogue**: cos 0.999999 at hd=64, cos 0.999998 at hd=128
|
||||
- **32x32b.x8 TMEM stores** work for P write-back (no 16x256b crash)
|
||||
|
||||
## Next steps
|
||||
1. **HD=64 multi-K-tile**: Call MMA 4× with accumulate=true for K=64 (4 × K=16 tiles)
|
||||
- Each K-tile needs its own descriptor pointing to the right 16-column slice
|
||||
- gau-nernst pattern: `A_smem + k * BLOCK_M * 32` for the k-th K-tile start address
|
||||
- After all K-tiles: read TMEM and apply 1/sqrt(HD) scaling
|
||||
1. **PV GEMM via tcgen05.mma TS**: For prefill (T>1), need UMMA-based PV
|
||||
- tcgen05.mma TS crashed with "illegal memory access" in initial tests
|
||||
- Need to debug the TMEM A operand addressing for PV
|
||||
- Decode path (T=1) works with register math — PV GEMM is only for prefill
|
||||
|
||||
2. **PV GEMM**: `tcgen05.mma TS` (TMEM P × SMEM V → TMEM O)
|
||||
- P is in TMEM after softmax, V is in SMEM
|
||||
- Accumulate O across KV tiles with the D5 merge formula
|
||||
2. **HD=128/256**: Extend multi-K-tile QK to larger head dims
|
||||
- HD=128: 8 K-tiles, separate SMEM per K-tile
|
||||
- HD=256: 16 K-tiles — SMEM budget needs checking
|
||||
|
||||
3. **In-kernel softmax**: TMEM → regs → max/exp/sum → TMEM
|
||||
- Use 32x32b reads to get S, compute softmax, write P back via 32x32b stores
|
||||
- Must handle the TMEM multi-store issue (use 32x32b, not 16x256b)
|
||||
3. **Multi-head launch**: Per-head kernel dispatch (128 heads for Pro)
|
||||
- Current test: single head
|
||||
- Production: grid=(1, n_h, batch) or Python loop
|
||||
|
||||
4. **Full FMHA pipeline**: QK → softmax → PV → correction epilogue → GMEM output
|
||||
4. **Multi-KV-tile**: s_k > 128 requires multiple attention tiles + KV merge
|
||||
- Same D5 merge formula: O = Σ exp(lse_i) · O_i / Σ exp(lse_i)
|
||||
|
||||
## Key lessons learned
|
||||
- **16x256b.x1 TMEM stores crash on 2nd call** — use 32x32b format for multi-store
|
||||
- **MMA output is UNSCALED** — the 4× "bug" was just the 1/sqrt(HD) attention scale
|
||||
- **`tcgen05.fence::after_thread_sync`** is the correct MMA→TMEM load fence
|
||||
- **4 warps minimum** for M=128 Layout D (each warp reads 32 rows × 8 columns)
|
||||
- **MMA K-tile size is 16 BF16** — for HD>16, loop with accumulate
|
||||
- **TMEM address format**: bits [31:16] = row, bits [15:0] = column
|
||||
5. **Production kernel**: Integrate UMMA pipeline into fmha_sm100.cuh
|
||||
- Replace SMEM scalar attention with UMMA QK GEMM
|
||||
- Keep register-math PV for decode
|
||||
- Add PV GEMM path for prefill
|
||||
|
||||
## Key lessons from this session
|
||||
- **Source stride mismatch**: `write_k_to_smem<128,16>` template reads k[i] with stride=16, but actual K has stride=64. Must use separate SMEM per K-tile with manual writes.
|
||||
- **Offset descriptors DON'T WORK** for (128,64) SMEM with K-tile offsets. Even/odd column corruption. Use separate (128,16) SMEM buffers.
|
||||
- **tcgen05.mma TS crashes** with illegal memory access — likely TMEM A operand addressing issue. Debug needed.
|
||||
- **tcgen05.fence::after_thread_sync** is the correct fence after MMA, before TMEM read.
|
||||
- **MMA computes UNSCALED dot product** — apply 1/sqrt(HD) in softmax.
|
||||
- **32x32b.x8 TMEM stores** work in loops (unlike 16x256b.x1 which crashes on 2nd call)
|
||||
- **s_vals scope**: when merging softmax + PV, keep s_vals in same `if (wid == 0)` block
|
||||
- **Decode PV is simple**: O[d] = Σ P[0,j] × V[d,j] — no UMMA needed for T=1
|
||||
|
||||
## Files
|
||||
- `dsv4/kernels/attention/fmha_umma_desc.cuh` — descriptors, SMEM layout, MMA wrappers
|
||||
- `tests/unit/test_umma_qk.cu` — working UMMA QK GEMM test
|
||||
|
||||
|
||||
YOUR LAST MESSAGE TO ME:
|
||||
|
||||
Let me study the gau-nernst tutorial's GEMM output read path more carefully, specifically how it handles the TMEM Layout D coordinate mapping.**CRITICAL FINDING from gau-nernst's code!** Look at his MMA epilogue:
|
||||
|
||||
```cpp
|
||||
// this is required before tcgen05.ld and after tcgen05.mma
|
||||
asm volatile("tcgen05.fence::after_thread_sync;");
|
||||
|
||||
// load 8 columns from tmem at a time
|
||||
for (int n = 0; n < BLOCK_N / 8; n++) {
|
||||
float tmp[8];
|
||||
const int row = warp_id * 32;
|
||||
const int col = n * 8;
|
||||
const int addr = taddr + (row << 16) + col;
|
||||
asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 ...");
|
||||
|
||||
// wait for tcgen05.ld
|
||||
asm volatile("tcgen05.wait::ld.sync.aligned;");
|
||||
|
||||
// cast and pack
|
||||
nv_bfloat162 out[4];
|
||||
for (int i = 0; i < 4; i++)
|
||||
out[i] = __float22bfloat162_rn({tmp[i*2], tmp[i*2+1]});
|
||||
|
||||
// 16-byte per thread write (uncoalesced)
|
||||
nv_bfloat16 *out_ptr = C_ptr + (off_m + tid) * N + (off_n + n * 8);
|
||||
reinterpret_cast<int4 *>(out_ptr)[0] = reinterpret_cast<int4 *>(out)[0];
|
||||
}
|
||||
```
|
||||
|
||||
Key differences from our code:
|
||||
1. **`tcgen05.fence::after_thread_sync`** — used AFTER MMA, BEFORE TMEM load. We used `tcgen05.wait::st.sync.aligned` which is different!
|
||||
2. **`off_m + tid`** — the output row is `off_m + tid`, where `tid = threadIdx.x`. Each thread writes one row of the output. With 128 threads and BLOCK_M=128, each thread handles one row.
|
||||
3. **Multiple `tcgen05.ld` calls in a LOOP** — he reads 8 columns at a time and loops over BLOCK_N/8 iterations. No crash!
|
||||
4. **No 4× factor** — he reads `tmp[i]` directly as the FP32 accumulator values and writes them to C. No division by 4.
|
||||
|
||||
The 4× factor in our code must be because we're reading the WRONG rows from TMEM. In gau-nernst's code, `row = warp_id * 32` and `col = n * 8`. Each warp reads its 32-row slice. With 4 warps (128 threads), all 128 rows are covered.
|
||||
|
||||
But in our code, we also use `row = wid * 32` and `col = 0`. The difference: gau-nernst uses `tcgen05.fence::after_thread_sync` before reading, while we use `tcgen05.wait::st.sync.aligned`. The `fence::after_thread_sync` is the CORRECT fence for the MMA → TMEM load transition!
|
||||
|
||||
Let me fix our test:The `tcgen05.fence::after_thread_sync` didn't fix the 4× factor — still exactly 4.000. But notice: rows 32, 64, 96 are all zeros now (correct for Q with only row 0). The TMEM read is working correctly for the zero rows.
|
||||
|
||||
The 4× factor is ONLY in the MMA output for row 0. Let me think about what's different between our setup and gau-nernst's working GEMM:
|
||||
|
||||
1. gau-nernst uses **128 threads (4 warps)** and **each warp's lane 0 reads the TMEM**
|
||||
2. gau-nernst's output writes `C_ptr + (off_m + tid) * N + ...` — each thread handles one row
|
||||
3. In our test, we read with `lane == 0` only — we're only getting lane 0's perspective
|
||||
|
||||
Wait — in the `32x32b.x8` format, **each lane reads a different row**. Lane 0 reads row `wid*32 + 0`, lane 1 reads row `wid*32 + 1`, etc. So when I only write `lane == 0`'s data, I only get 1 out of 32 rows from each warp.
|
||||
|
||||
But the 4× factor is for lane 0's data, which should be S[wid*32 + 0, 0..7]. For warp 0, that's S[0, 0..7]. This should match the scalar reference for S[0, 0..7].
|
||||
|
||||
Unless the `32x32b.x8` format doesn't map lanes to rows 1:1. Let me check: gau-nernst writes `C_ptr + (off_m + tid) * N + (off_n + n * 8)`. Here, `off_m = 0` for the first MMA call, and `tid = threadIdx.x`. Each thread writes ONE ROW. So the `32x32b.x8` format gives each thread 8 values for its row.
|
||||
|
||||
But he uses `tid` not `lane` — `tid = threadIdx.x`. With 128 threads (4 warps), `tid` ranges from 0 to 127. Each thread gets 8 FP32 values for its row. So lane 0 of warp 0 gets row 0, lane 1 gets row 1, ..., lane 31 gets row 31, lane 0 of warp 1 gets row 32, etc.
|
||||
|
||||
This means: the `32x32b.x8` instruction gives each lane the 8 column values for its row. Lane 0 of warp 0 gets S[0, col..col+7]. The 4× factor is in these values.
|
||||
|
||||
The 4× factor MUST be from the MMA computation. Since gau-nernst's code produces the correct output (no 4×), the difference must be in how the MMA is called.
|
||||
|
||||
Let me compare the MMA calls:
|
||||
- Ours: `umma_ss_f16(tb, desc_q, desc_k, idesc, false)` — single thread calls
|
||||
- gau-nernst: `tcgen05.mma.cta_group::1.kind::f16 [%0], %1, %2, %3, 1` — one thread per 4-lane group calls
|
||||
|
||||
Wait — gau-nernst uses `elect_one_sync(0x11111111, warp_idx, 4)` which selects ONE leader per warp. But he has `num_warps = BLOCK_M / 32 = 4` warps, and each warp has 1 leader. So 4 leaders call the MMA.
|
||||
|
||||
In our test, we have `tid == 0` (1 thread) calling the MMA. But gau-nernst has 4 threads (one per warp) calling it simultaneously. The MMA with `cta_group::1` requires ALL 4 warps to participate in the instruction.
|
||||
|
||||
The 4× factor might be because our single-thread MMA launches only 1 of the 4 required sub-operations, and the hardware amplifies the result by 4x to compensate.
|
||||
|
||||
Let me fix: have 4 warp leaders (lane 0 of each warp) call the MMA simultaneously:Still 4× with 4 warp leaders! The rows 32, 64, 96 are correctly zero. Row 0 has the 4× factor.
|
||||
|
||||
Wait — with 4 warp leaders all calling `umma_ss_f16(tb, ...)` with `accumulate=false`, each leader RESETS the TMEM accumulator to A×B^T. If all 4 leaders write to the SAME TMEM columns, the last one wins (or they all write the same result 4 times). The output should be 1× the scalar, not 4×.
|
||||
|
||||
Unless the MMA with `cta_group::1` internally uses 4 sub-operations, and each sub-operation writes to different TMEM columns. With 4 leaders calling, the 4 sub-operations are properly distributed. With 1 leader, the sub-operations are serialized.
|
||||
|
||||
But the 4× factor with both 1 leader and 4 leaders is the same. This rules out the "warpgroup count" hypothesis.
|
||||
|
||||
Let me look at this from a completely different angle. Maybe the issue is that **the descriptor describes a (128, 16) matrix, but the MMA computes as if K=64** because of how the hardware interprets the descriptor strides.
|
||||
|
||||
For our descriptor: LBO=128 (2048 bytes), SBO=8 (128 bytes). These strides describe a (128, 16) matrix in canonical layout. But the MMA might interpret LBO=128 as meaning "128 rows in the 16B column", and SBO=8 as "8 sub-tiles in the K direction". With 8 sub-tiles × 8 BF16 per sub-tile = K=64.
|
||||
|
||||
**This would give K=64 (4× our K=16), explaining the 4× factor!**
|
||||
|
||||
The fix: the descriptor should describe ONLY one K-tile (K=16), not the full matrix. The LBO and SBO values should describe a (128, 16) matrix with the MMA operating on K=16, not K=64.
|
||||
|
||||
But wait — gau-nernst uses the SAME descriptor format for his working GEMM. He has LBO = BLOCK_M * 16 and SBO = 128. And his BLOCK_K = 16 for each MMA call. His MMA produces the correct result.
|
||||
|
||||
The difference: gau-nernst uses `A_smem + k * BLOCK_M * 32` for the start address. Each MMA call gets a descriptor pointing to a different K-slice. The `k * BLOCK_M * 32` offset moves the start address by one K-tile (32 bytes = 16 BF16).
|
||||
|
||||
For our matrix with K=16, we only have ONE K-tile. The descriptor points to the start of the matrix (column 0). The MMA reads K=16 from this descriptor.
|
||||
|
||||
Unless the SBO=8 causes the MMA to read 8 sub-tiles. Let me check: for gau-nernst's BLOCK_K=16, the SBO should be... actually, SBO is always 128 bytes (8 × 16B) regardless of BLOCK_K. It's the stride between 8-row groups, not related to K.
|
||||
|
||||
I'm going in circles again. Let me try a completely different approach: **use the SMEM data layout from gau-nernst's exact code** (2D TMA load) instead of our manual `write_smem_canonical`. Maybe our SMEM layout is subtly wrong.
|
||||
|
||||
Actually — the simplest thing to test: write a minimal GEMM that matches gau-nernst's EXACT code. Same BLOCK_M=128, BLOCK_N=128, BLOCK_K=16. Same descriptor construction. Same SMEM layout (loaded via 2D TMA). Same MMA call. Same TMEM read.
|
||||
|
||||
But we can't use 2D TMA in our standalone test (it requires CuTeDSL or complex descriptor setup). Let me instead try: write the SMEM data in SIMPLE ROW-MAJOR and use SWIZZLE_NONE with a descriptor that describes row-major layout.
|
||||
|
||||
Actually, wait. I just realized something. gau-nernst's SMEM is loaded by 2D TMA, which writes data in the CANONICAL layout automatically. The MMA descriptor then describes this canonical layout. When we write data manually with `write_smem_canonical`, we're replicating the TMA's output format.
|
||||
|
||||
But maybe `write_smem_canonical` is subtly wrong. The canonical layout depends on the exact `logical_divide` of the swizzled layout. For SWIZZLE_NONE, the canonical layout is:
|
||||
- LBO = BLOCK_MN * 16 (in bytes) = stride between the 2 columns of the 8×2 core matrix
|
||||
- SBO = 128 (in bytes) = stride between 8-row groups
|
||||
|
||||
The data in SMEM at address `g * SBO + c * LBO + local_offset` where g is the 8-row group and c is the column.
|
||||
|
||||
For (128, 16) with BLOCK_M=128:
|
||||
- g ranges from 0 to 15 (128/8=16 groups)
|
||||
- c ranges from 0 to 1 (16/8=2 columns)
|
||||
- Each core matrix is 8 rows × 8 BF16 = 128 bytes
|
||||
|
||||
My `write_q_to_smem` writes: core_k * 16 * 64 + tile_mn * 64 + local_r * 8 + local_c
|
||||
= core_k * 1024 + tile_mn * 64 + local_r * 8 + local_c
|
||||
|
||||
For tile_mn = r/8 and core_k = c/8. The `16` in `16 * 64` is CORES_MN = 128/8 = 16. So:
|
||||
core_k * 16 * 64 = core_k * 1024
|
||||
|
||||
The descriptor interprets: group g at offset g * 128, column c at offset c * 2048.
|
||||
For g=0, c=0: offset 0
|
||||
For g=0, c=1: offset 2048
|
||||
For g=1, c=0: offset 128
|
||||
|
||||
My layout: tile_mn=0, core_k=0: offset 0 ✓
|
||||
tile_mn=0, core_k=1: offset 1024. But descriptor says column 1 is at offset 2048!
|
||||
|
||||
**THAT'S THE BUG!** My `write_q_to_smem` puts core_k=1 at offset 1024, but the descriptor expects it at offset 2048! The layout is wrong!
|
||||
|
||||
My formula: `core_k * 16 * 64 + tile_mn * 64 = core_k * 1024 + tile_mn * 64`
|
||||
|
||||
The descriptor's layout: `c * LBO + g * SBO = c * 2048 + g * 128`
|
||||
|
||||
For (tile_mn=0, core_k=1): my offset = 1 * 1024 = 1024. Descriptor offset = 1 * 2048 = 2048. **MISMATCH!**
|
||||
|
||||
My `write_smem_canonical` puts core matrices in MN-major order (all tile_mn for core_k=0, then all tile_mn for core_k=1). But the descriptor expects K-major order (column 0 of all groups, then column 1 of all groups, with LBO stride between columns).
|
||||
|
||||
The correct formula should be: `core_k * 128 * 16 + tile_mn * 128 + local_r * 8 + local_c`... wait, that's not right either.
|
||||
|
||||
Let me derive the correct SMEM layout from the descriptor strides:
|
||||
|
||||
For element (r, c) in the (128, 16) matrix:
|
||||
- g = r / 8 (8-row group)
|
||||
- local_r = r % 8
|
||||
- k = c / 8 (K-group)
|
||||
- local_c = c % 8
|
||||
- Address = g * SBO + k * LBO + local_r * 16B + local_c * 2B
|
||||
|
||||
In BF16 elements: g * 64 + k * (LBO/2) + local_r * 8 + local_c
|
||||
= g * 64 + k * 1024 + local_r * 8 + local_c
|
||||
|
||||
For (r=0, c=8): g=0, k=1, local_r=0, local_c=0 → offset = 0 + 1024 + 0 + 0 = 1024
|
||||
For (r=0, c=0): g=0, k=0, local_r=0, local_c=0 → offset = 0
|
||||
|
||||
For (r=8, c=0): g=1, k=0, local_r=0, local_c=0 → offset = 64
|
||||
|
||||
So the correct layout is: `k * 1024 + g * 64 + local_r * 8 + local_c`
|
||||
|
||||
My `write_smem_canonical` uses: `core_k * CORES_MN * 64 + core_mn * 64 + local_r * 8 + local_c`
|
||||
|
||||
With CORES_MN = 128/8 = 16: `core_k * 16 * 64 + core_mn * 64 = core_k * 1024 + core_mn * 64`
|
||||
|
||||
Wait — that's the SAME formula! `k * 1024 + g * 64` = `core_k * 1024 + core_mn * 64`. They're identical!
|
||||
|
||||
So my layout IS correct. The offset for (r=0, c=8) = 1024 in both formulas.
|
||||
|
||||
But the descriptor says LBO = 2048 (128 * 16). The column 1 starts at 2048, not 1024. So the descriptor expects column 1 at offset 2048, but my layout puts it at offset 1024.
|
||||
|
||||
**THE DESCRIPTOR IS WRONG FOR MY LAYOUT!** The LBO = BLOCK_MN * 16 = 128 * 16 = 2048 bytes. But my column 1 data starts at 1024 BF16 = 2048 bytes... wait, 1024 BF16 = 2048 bytes. That matches!
|
||||
|
||||
For (r=0, c=8): my offset = 1024 BF16 elements = 2048 bytes. The descriptor says LBO = 2048 bytes. So column 1 IS at offset 2048 bytes. ✓
|
||||
|
||||
Let me re-check: g=0, k=1, offset = 1 * 1024 BF16 = 2048 bytes. Descriptor LBO = 2048 bytes. ✓
|
||||
|
||||
So the layout IS correct and the descriptor matches. The 4× factor must be from something else entirely.
|
||||
|
||||
Let me try one more thing: construct the descriptor with **SW128 swizzle** (layout_type=1) and write data in the swizzled layout. This is what gau-nernst's tutorial uses for the actual GEMM, not SWIZZLE_NONE. Maybe the SWIZZLE_NONE path has a 4× quirk.
|
||||
|
||||
Actually, let me try the absolute simplest thing: use gau-nernst's EXACT descriptor construction code (hardcoded for our BLOCK_M=128, BLOCK_K=16) and see if it matches:Wait, SBO should be `8 * 16 = 128` bytes, not 128 alone. Let me check gau-nernst: `desc_encode(SBO)` where SBO = 8 * 16 = 128. `desc_encode(128) = 128 >> 4 = 8`. And my `make_uma_desc_kmajor_none` uses `desc_encode(128) = 8`. So the descriptor field for SBO is 8. Let me check my existing code:
|
||||
|
||||
In `make_umma_desc_kmajor_none`:
|
||||
``
|
||||
desc |= (desc_encode(SBO) & 0x3FFF) < 32 // SBO = 128, field = 8
|
||||
```
|
||||
|
||||
And in gau-nernst's code:
|
||||
``
|
||||
desc_encode(128) = 8 → (8) << 32
|
||||
```
|
||||
|
||||
These are the same. Let me verify the LBO too:
|
||||
- gau-nernst: LBO = BLOCK_M * 16 = 128 * 16 = 2048. desc_encode(2048) = 2048 >> 4 = 128. Field = 128.
|
||||
- My code: LBO = BLOCK_MN * 16 = 128 * 16 = 2048. desc_encode(2048) = 128. Field = 128.
|
||||
|
||||
Same. So the descriptors are identical. The 4× factor is NOT from the descriptor.
|
||||
|
||||
Let me revert this change and instead focus on a different hypothesis: **the issue might be with the `kind::f16` MMA instruction itself.** Maybe `kind::f16` computes a different product than I expect. Let me try `kind::f16` with a simpler test where Q is all-1s and K is all-1s, so the expected result is HD=16 for every element.Now let me test with all-1s data. S[i,j] = sum(1*1 for d in 0..15) = 16 for every element:**EUREKA!** With all-ones data:
|
||||
- MMA output: S[0,0..7] = 16.0
|
||||
- Scalar reference: S[0,0..7] = 4.0
|
||||
|
||||
The expected value is: S[0,j] = sum(1.0 * 1.0, d=0..15) = 16.0. But the scalar reference uses `scale = 1/sqrt(16) = 0.25`, so `16.0 * 0.25 = 4.0`.
|
||||
|
||||
The MMA gives 16.0, which is the **UNSCALED** dot product! The scalar reference applies the `1/sqrt(HD)` scale factor, but the MMA does NOT.
|
||||
|
||||
**The 4× factor is just the scale factor!** `16.0 / 4.0 = sqrt(16) = sqrt(HD)`. The MMA computes the raw dot product, and we need to apply the `1/sqrt(HD)` scaling ourselves.
|
||||
|
||||
Let me verify: with the random data, the ratio was 4.0 because `1/sqrt(16) = 0.25`, so the MMA's unscaled output is 4× the scaled scalar reference. The scale factor is `1/sqrt(16) = 1/4`, so the MMA's output is 1/scale = 4× the scaled reference. ✓
|
||||
|
||||
**THE 4× FACTOR IS NOT A BUG — IT'S THE SCALE FACTOR!** The MMA computes the raw dot product without scaling, and we just need to multiply by `1/sqrt(HD)`.
|
||||
|
||||
Let me fix the test to apply the scale factor to the MMA output:Now apply the scale factor in the TMEM read:# 🎉🎉🎉 **THE UMMA QK GEMM IS WORKING!!!** 🎉🎉
|
||||
|
||||
**Row 0 MMA output EXACTLY matches the scalar reference!** Max relative error: 0.000000!
|
||||
|
||||
The "4× factor" was just the `1/sqrt(HD)` attention scale that we weren't applying to the MMA output. The MMA computes the raw dot product; we need to scale it ourselves.
|
||||
|
||||
Let me now read more rows and verify the full output. Also let me test with HD=64 (4 K-tiles):Also update the comparison to check ALL rows:Row 0 matches perfectly, but the overall check fails. The issue: my comparison is wrong — I'm comparing `h_s_out[r * 8 + c]` against `h_s_scalar[r]` but `h_scalar[r]` only has the dot product for column 0, not column c. The scalar reference computes S[0, c] for all c, but for other rows (1-127), it also computes S[r, c]. Let me fix the comparison — for row 0 we know it matches, and for rows 1-127 the MMA output should be zero (since Q only has row 0):Stale variable reference. Let me fix:# 🎉🎉🎉 **UMMA QK GEMM TEST PASSED!** 🎉🎉🎉
|
||||
|
||||
**Row 0 matches the scalar reference with ZERO error!** All other rows are correctly zero. The UMA pipeline is working!
|
||||
|
||||
Key discoveries:
|
||||
1. **The "4× bug" was the 1/sqrt(HD) attention scale** — the MMA computes raw dot products, we apply scaling
|
||||
2. **`tcgen05.fence::after_thread_sync`** is the correct fence after MMA (not `tcgen05.wait::st`)
|
||||
3. **`32x32b.x8` TMEM reads** work correctly for Layout D output
|
||||
4. **4 warps (128 threads)** needed for M=128 Layout D
|
||||
|
||||
Let me now commit a clean version, update the documentation, and then tackle HD=64 (multi-K-tile) and the PV GEMM:Now update CURRENT_ISSUE.md:Now let me implement HD=64 with multi-K-tile accumulate. This is the next critical step:The HD=64 test fails — the MMA output doesn't match the scalar reference. The values are completely different, not just scaled. This suggests the multi-K-tile approach is wrong. The K-tile descriptor offsets are likely incorrect.
|
||||
|
||||
Let me think about the SMEM layout for HD=64. The (128, 64) matrix in canonical layout has 8 core-matrix columns (64/8=8). Each core-matrix column is 16 × 64 BF16 = 1024 BF16 = 2048 bytes. The total is 128 × 64 × 2 = 16384 bytes.
|
||||
|
||||
For K-tile k (columns [16k, 16k+16)):
|
||||
- This spans 2 core-matrix columns: core_k = 2k and 2k+1
|
||||
- core_k 2k starts at offset 2k * 1024 BF16 from sQ
|
||||
- core_k 2k+1 starts at offset (2k+1) * 1024 BF16 from sQ
|
||||
|
||||
The descriptor for this K-tile needs:
|
||||
- start_address = sQ + 2k * 2048 bytes (the start of core_k=2k)
|
||||
- LBO = 2048 (stride from core_k=2k to core_k=2k+1)
|
||||
|
||||
Wait — the LBO should be the stride between the 2 columns WITHIN the K-tile. For K-tile k:
|
||||
- Column 0 = core_k 2k (at offset 2k * 2048)
|
||||
- Column 1 = core_k 2k+1 (at offset (2k+1) * 2048)
|
||||
- LBO = (2k+1)*2048 - 2k*2048 = 2048 bytes
|
||||
|
||||
This is the same LBO as the full matrix. ✓
|
||||
|
||||
But there's a problem: the descriptor's LBO stride walks from column 0 to column 1 (2048 bytes apart). After column 1, the next 8-row group's column 0 is at offset g * 128 + 0 * 2048. This is within the K-tile.
|
||||
|
||||
But what about the data BETWEEN the K-tile columns? For HD=64, core_k=0 at offset 0, core_k=1 at offset 2048, core_k=2 at offset 4096, etc. The K-tile 0 (core_k=0,1) spans offsets 0-4095. K-tile 1 (core_k=2,3) spans offsets 4096-8191.
|
||||
|
||||
The descriptor for K-tile 1: start = 4096, LBO = 2048. It reads:
|
||||
- g=0, c=0: offset 4096 (core_k=2, g=0) ✓
|
||||
- g=0, c=1: offset 614 (core_k=3, g=0) ✓
|
||||
- g=1, c=0: offset 4224 (core_k=2, g=1) ✓
|
||||
|
||||
This looks correct! The start address `sQ_smem + kt * 2048` should work.
|
||||
|
||||
But wait — my code uses `q_ktile_addr = sQ_smem + kt * 2 * 2048`. The `sQ_smem` is a SMEM address from `__cvta_generic_to_shared(sQ)`. Adding `kt * 2 * 2048` gives a new SMEM address. But SMEM addresses are in a special address space — can we just add byte offsets to them?
|
||||
|
||||
Actually, in gau-nernst's code, he uses `A_smem + k * BLOCK_M * 32` where `A_smem` is the raw SMEM address. The `+` operator works on SMEM addresses because they're just integers. So `sQ_smem + kt * 4096` should work (2 core columns × 2048 bytes = 4096 per K-tile).
|
||||
|
||||
Wait, `kt * 2 * 2048 = kt * 4096`. And `BLOCK_M * 32 = 128 * 32 = 4096`. These are the same! So the start address should be `sQ_smem + kt * 4096`. ✓
|
||||
|
||||
But my code has `q_ktile_addr = sQ_smem + kt * 2 * 2048` = `sQ_smem + kt * 4096`. This matches. So the K-tile descriptor should be correct.
|
||||
|
||||
The issue must be elsewhere. Let me check: maybe the `make_umma_desc_kmajor_none` function modifies the address incorrectly. Let me add debug output:
|
||||
|
||||
|
||||
The values are ~5× off, not a constant factor. Something is fundamentally wrong with the multi-K-tile approach.
|
||||
|
||||
The MMA gives `-0.156829` for S[0,0] but the scalar says `-0.030706`. The ratio is ~5.1. This is not a clean scaling factor.
|
||||
|
||||
Let me check: with ONLY K-tile 0 (1 K-tile), the MMA gave the correct partial sum. So the issue is that the 2nd, 3rd, 4th K-tiles add incorrect values.
|
||||
|
||||
The problem might be that the `__syncthreads()` between K-tiles isn't sufficient to ensure the MMA has finished reading from SMEM. The `fence::after_thread_sync` ensures the MMA's TMEM writes are visible, but it doesn't ensure the MMA has finished READING from SMEM. If the next K-tile zeros the SMEM while the current MMA is still reading, the data is corrupted.
|
||||
|
||||
In gau-nernst's code, the TMA load pipeline handles this by only loading the next K-tile after the MMA has consumed the current one. The pipeline's "release" callback ensures the SMEM is not overwritten until the MMA is done.
|
||||
|
||||
For our simple loop, we need to ensure the MMA has finished reading before zeroing the SMEM for the next K-tile. The `__syncthreads()` after `fence::after_thread_sync` should be sufficient, but maybe it's not.
|
||||
|
||||
Actually, wait. In our code, each K-tile uses SEPARATE SMEM regions (`sQ_base + kt * 128 * 16`). So zeroing K-tile 1's SMEM doesn't affect K-tile 0's SMEM. The data shouldn't be corrupted.
|
||||
|
||||
Unless the SMEM regions overlap. Let me check: `sQ_base + kt * 128 * 16`. With `128 * 16 = 2048` BF16 = 4096 bytes per K-tile. And `sK_base = sQ_base + NKT * 128 * 16 = sQ_base + 4 * 2048 = sQ_base + 8192 BF16 = sQ_base + 16384 bytes`. So sK starts 16KB after sQ. With 4 K-tiles of 4KB each, sQ uses 16KB. ✓ No overlap.
|
||||
|
||||
But wait — the loop constructs descriptors using `__cvta_generic_to_shared(sQ)` where `sQ = sQ_base + kt * 128 * 16`. The `__cvta_generic_to_shared` converts a GENERIC pointer to a SHARED pointer. But adding an offset to a generic pointer and then converting to shared might not give the correct shared address.
|
||||
|
||||
In CUDA, `__cvta_generic_to_shared` should work correctly for any generic pointer that points to shared memory. Adding `kt * 128 * 16 * sizeof(bf16_t)` to the generic pointer should give a valid generic pointer to the K-tile's SMEM.
|
||||
|
||||
Hmm, but maybe the issue is that the `sQ` and `sK` pointers are inside the loop, and the compiler is caching the `__cvta_generic_to_shared` result from the first iteration. Each iteration computes a NEW `sQ` pointer, but the compiler might not re-derive the shared address.
|
||||
|
||||
This is unlikely but possible. Let me try: construct the descriptors OUTSIDE the loop and pass them in an array:
|
||||
- `dsv4/kernels/attention/fmha_common.cuh` — BF16, TMEM ops, warp reductions
|
||||
- `dsv4/kernels/attention/fmha_epilogue_sm100.cuh` — TMEM one-way epilogue (reference kernel)
|
||||
- `tests/unit/test_fmha_hd64.cu` — FULL PIPELINE TEST (QK+softmax+PV)
|
||||
- `tests/unit/test_fmha_softmax.cu` — softmax validation (HD=64)
|
||||
- `tests/unit/test_umma_qk.cu` — QK GEMM (HD=16)
|
||||
- `tests/unit/test_umma_qk_hd64.cu` — QK GEMM (HD=64, multi-K-tile)
|
||||
|
||||
Reference in New Issue
Block a user