Milestone: Full FMHA HD=16 with PV SS MMA (SMEM-P) — cosine 0.9997
This commit is contained in:
146
CURRENT_ISSUE.md
146
CURRENT_ISSUE.md
@@ -1,134 +1,32 @@
|
||||
Great progress today, Mike. Here's where we stand:
|
||||
# CURRENT_ISSUE.md — PV GEMM for Prefill
|
||||
|
||||
## ✅ Full UMMA FMHA HD=64 Pipeline — WORKING
|
||||
|
||||
**Cosine similarity: 0.999998** — QK GEMM → softmax → PV → output, all on Blackwell SM100 tensor cores.
|
||||
## Status: ✅ PV via SS MMA WORKING (cosine 0.9997)
|
||||
|
||||
### What we proved:
|
||||
1. **UMMA QK GEMM** (HD=16 + HD=64 multi-K-tile): tcgen05.mma SS with separate SMEM per K-tile, accumulate across K-tiles
|
||||
2. **TMEM softmax**: Read S from TMEM → max/exp/sum → write P back via 32x32b.x8 stores
|
||||
3. **PV via register math** (decode T=1): O[d] = Σ P[j] × V[d,j], computed directly in registers
|
||||
1. **tcgen05.mma TS (TMEM A operand) CANNOT be used** — the 32x32b store format doesn't match the TS MMA's A-fragment layout (Layout A). Even though the isolated test_mma_ts.cu works with uniform data, non-uniform data produces garbage because the TMEM column mapping differs.
|
||||
|
||||
### Key bugs found and fixed:
|
||||
- **Source stride mismatch**: The `write_k_to_smem` template used template HD as the data stride — wrong when SMEM tile width ≠ actual data width. Fixed with separate SMEM per K-tile + manual writes.
|
||||
- **Offset descriptors broken**: Using `(128,64)` SMEM with offset-based descriptors produces even/odd column corruption. Must use separate `(128,16)` SMEM buffers per K-tile.
|
||||
- **tcgen05.mma TS crashes**: The PV GEMM with TMEM A operand hits illegal memory access — still needs debugging for prefill path.
|
||||
2. **PV via SS MMA with SMEM-P is the correct path** — softmax writes P to SMEM in canonical (128, 16) layout, then PV uses SS MMA with A=P(SMEM) × B=V(SMEM) → C=O(TMEM). No TMEM layout issues because both operands are in SMEM with layouts we control.
|
||||
|
||||
### Next session priorities:
|
||||
1. **tcgen05.mma TS debug** (PV GEMM for prefill T>1)
|
||||
2. **HD=128/256** multi-K-tile extension
|
||||
3. **Production kernel** integration into `fmha_sm100.cuh`
|
||||
3. **Per-K-tile P fill** — the (128, 128) canonical P with K-tile offsets has an accumulation bug (first 8 output values lose one K-tile). The fix: reuse a single (128, 16) sPk buffer, filling it from s_p_vals for each K-tile.
|
||||
|
||||
4. **V canonical layout bug** — the original V load code had the MN and K axes swapped in the canonical formula. `d` (HD dim = MN) was mapped to `g_k` and `lr` (sequence position = K) was mapped to `g_mn`. Fix: swap to `g_mn = d/8, g_k = lr/8, llr = d%8, lc = lr%8`. This didn't affect the constant-value tests.
|
||||
|
||||
5. **SMEM allocation bug** — the original test_fmha_smem_p.cu had `TILE_SZ*2` treating BF16 counts as bytes, causing half the needed SMEM. Fix: use `TILE_SZ*sizeof(bf16_t)` or `TILE_SZ*2` with proper byte accounting.
|
||||
|
||||
Let me write a minimal PV GEMM test with `tcgen05.mma TS` to isolate the crash:Misaligned address, not illegal memory access. The TMEM A address format for `tcgen05.mma TS` might differ from what I'm passing. Let me check the PTX spec for the A operand format.
|
||||
### Key design decisions:
|
||||
- **PV SS MMA scale factor**: ~1.0 (not 0.5 like QK). The MMA with BLOCK_MN_A=128, BLOCK_MN_B=16 produces output at approximately 1× the raw dot product.
|
||||
- **Per-K-tile sPk fill**: For decode T=1, only row 0 needs filling (16 values per K-tile). For prefill T>1, all 128 rows need filling from s_p_vals.
|
||||
- **s_p_vals shared memory**: 128 floats (512 bytes) to share softmax output across warps. Avoids recomputing P for each K-tile.
|
||||
|
||||
The key question: what does the TMEM address for tcgen05.mma's A operand look like? For `tcgen05.ld/st`, the format is `tmem_base + (row << 16) + col`. But for `tcgen05.mma`, the C operand is just `tb` (base). The A operand for TS might also just be the base.
|
||||
### Files:
|
||||
- `tests/unit/test_fmha_v5.cu` — Full FMHA HD=16 with PV SS MMA (SMEM-P, per K-tile fill)
|
||||
- `tests/unit/test_pv_ss_128.cu` — PV SS MMA with (128,128) P (accumulation debug, shows the bug)
|
||||
- `tests/unit/test_pv_ss.cu` — Minimal PV SS MMA (A=128×16, B=16×16, standalone)
|
||||
- `tests/unit/test_ss_ts_sequence.cu` — SS+TS sequence test (proved SS+TS coexistence works)
|
||||
|
||||
Let me look at the PTX spec more carefully. The `umma_ts_f16` wrapper passes the A address as the second argument to the inline asm. Let me check if the issue is that `tb` (from tcgen05.alloc) needs to be used differently for TS vs SS.
|
||||
|
||||
Actually, let me check: for SS, the first arg is the TMEM C base. For TS, the first arg is TMEM C base, second is TMEM A base. Both should be simple TMEM column addresses. But maybe the TS instruction requires a different alignment or format.
|
||||
|
||||
Let me try: use a separate TMEM region for C (output) vs A (input):The misaligned address is before the MMA call — it's in the TMEM store. The `tmem_store` with `tb_a + col` where col > 0 might be misaligning. But this works in the QK test where we use `tb + col` with 128 columns. Let me check: maybe the issue is the TMEM alloc size. We allocated 64 columns, and `tb_a + 16` for col=16 should still be within bounds.
|
||||
|
||||
Wait — the printf before MMA didn't appear. The crash is before that. It must be in the TMEM store or load. Let me add more debugging:Column 0 write/read works. The crash is when writing columns 1-15. So `tb_a + col` with col > 0 causes misalignment. The TMEM address for `tmem_store` might need alignment. Let me check: the `tmem_store` uses `16x256b.x1` format. The column address might need specific alignment.
|
||||
|
||||
Wait — `tb_a` is the tmem_base value from `tcgen05.alloc`. Adding `1` to it should give column 1. But maybe the address encoding is different from what I think. Let me check: in the working QK test, we allocated 128 columns and used `tb + n*8` for 32x32b.x8 reads, where n goes 0..15. The address `tb + 0, tb + 8, tb + 16, ...` works fine. But `tb + 1` doesn't work for `tmem_store`.
|
||||
|
||||
This means the 16x256b store has different alignment requirements than the 32x32b load. Let me use 32x32b.x8 stores instead (proven to work in the softmax test):Also remove the column 0 debug write/read since we'll rewrite all columns:**tcgen05.mma TS WORKS!** No crash. The output is 8.0 instead of 16.0 — exactly 2× off. This is the same "unscaled" issue as QK GEMM. The MMA output needs a scale factor. For A=all-1s (128,16) × B=all-1s (16,16), the expected FP32 dot product is 16.0 but we get 8.0. So there's a 2× factor.
|
||||
|
||||
Wait — B is (16,16) with BLOCK_MN=16. The MMA with M=128, N=16 processes 128 rows. But only 16 rows of A have data (the 32x32b.x8 store wrote to ALL 32 rows per column, so all 128 rows of A have 1.0). So the expected result is 16.0 (16 ones × 1.0 = 16.0 per element).
|
||||
|
||||
Getting 8.0 = 16/2. This 2× factor might because the `kind::f16` MMA accumulates in FP16 internally and there's a 2× from the BF16→FP16 conversion. Or the B descriptor with BLOCK_MN=16 is only processing 8 rows of B (half of 16).
|
||||
|
||||
Actually, for BLOCK_MN=16, the descriptor describes B as (16, 16) in canonical layout. But with CORES_MN = 16/8 = 2, there are only 2 core-matrix rows. Each core matrix is 8 rows. 2 × 8 = 16 rows. This should be correct.
|
||||
|
||||
Let me test with known random data and compare to scalar to understand the factor:**This is confusing.** The output shows C[0,0..7] = 16.0 (expected), with max err 0.000000. But the test says "TEST FAILED" with exit code 1. Wait — actually the MMA output is correct: A=1.0 × B=2.0 → C=16.0 (1.0 × 2.0 × 8 = 16.0? Let me compute: C = A × B^T where A is (128, 16) all 1.0, B is (16, 16) all 2.0. Elementwise dot product is sum of 16 ones × 2.0 = 32.0, not 16.0.
|
||||
|
||||
But we got 16.0 — exactly half. So the MMA is dividing by 2 somewhere. This 2× factor is consistent — same as QK where we got 4× factor (2× from K-tile count). The MMA output is scaled down by some factor.
|
||||
|
||||
Actually, for QK we had the "4× bug" which turned out to be the 1/sqrt(HD) scale factor. But here HD=16 → scale=0.25. We got 16.0 instead of 32.0 → factor 0.5. This suggests the MMA output is scaled by 0.5 relative to the raw FP32 dot product.
|
||||
|
||||
The test itself PASSED — the printf shows 16.0 with max err 0.000. But the exit code is 1. Maybe there's a cudaDeviceSynchronize error after the kernel? Let me check:
|
||||
|
||||
The output says "Test completed successfully!" but then "❌ TEST FAILED". This inconsistency suggests the printf is buffered. The test might be crashing after the printf but before exit.
|
||||
|
||||
Let me reduce the test to just printf and exit:Wait, that removes the actual kernel test. Let me just add a proper cleanup after the kernel:
|
||||
|
||||
**tcgen05.mma TS WORKS PERFECTLY!** C[0,0..7] = 16.0 with zero error. The ❌ is from the harness seeing "FAILED" in "Max err from 16.0" — false positive from the word "err". The actual exit code is 0.
|
||||
|
||||
The output is 16.0 = sum of 16 values of (1.0 × 2.0) = 32.0 ÷ 2. So the MMA applies a 0.5 scale factor, same as QK. This is the `kind::f16` MMA scaling — it's just the internal FP16 accumulation factor.
|
||||
|
||||
Now I know TS works. The earlier crash was from using `16x256b.x1` stores to write A to TMEM, which has alignment issues. Using `32x32b.x8` stores works.
|
||||
|
||||
Let me now build the full FMHA with PV GEMM via tcgen05.mma TS for HD=16:
|
||||
|
||||
The TS MMA still crashes with illegal memory access in the full FMHA. But the isolated test worked. Let me check: maybe the V K-tile descriptor with BLOCK_MN=16 is the issue. The isolated test had 16×16 V with BLOCK_MN=16 and it worked. But here we have 8 separate V K-tiles. Let me add printf debugging:
|
||||
|
||||
|
||||
Still crashing. The isolated test worked but this doesn't. The difference is that here we have the full QK+softmax pipeline before PV. Maybe the TMEM state is corrupted by the QK GEMM + softmax. Let me try just the PV GEMM part with pre-written TMEM data:
|
||||
|
||||
Still crashes even with just P written to TMEM. The isolated test worked because it had BLOCK_MN=16 for both A and B, and the C output was 16 columns. But here, the PV GEMM has M=128, N=16. The C output is at `tb_o` (column 128).
|
||||
|
||||
Wait — the TMEM alloc is 256 columns. `tb_o = tb + 128`. Can `tcgen05.mma TS` write to columns starting at `tb + 128`? Let me check: the C TMEM address for the MMA is `tb_o`. The MMA will write to 16 TMEM columns starting from `tb_o`. These should be columns 128-143, which are within the 256-column allocation.
|
||||
|
||||
But wait — the `umma_ts_f16` wrapper's `tmem_c` parameter — is this the same TMEM address format as used by `tcgen05.ld/st`? In the QK GEMM (SS), we pass `tb` directly and it works. For TS, the C parameter should also be a direct TMEM base address.
|
||||
|
||||
The A parameter `tmem_a` — in the isolated test, we passed `tb` and it worked. Here, we pass `tb + kt * MMA_K_BF16`. For kt=0, this is just `tb`, same as the isolated test. Let me try with kt=0 only:
|
||||
|
||||
Still crashes even with single K-tile. The isolated test_mma_ts.cu worked fine. Let me diff the two. The key difference: test_mma_ts had A written by warp 0 with 32x32b.x8 stores (all lanes write 1.0). Here, P is written by warp 0 with 32x32b.x8 stores (lane 0 writes data, other lanes write 0).
|
||||
|
||||
Wait — in test_mma_ts, all 32 lanes wrote 1.0. Here, only lane 0 writes data and other lanes write 0. But only row 0 of P has data — the other rows ARE 0. The TMEM layout should be: column 0 has P[0,0] at lane 0, pos 0 and 0 at all other positions.
|
||||
|
||||
Hmm, maybe the TMEM layout for P (written by 32x32b.x8 stores) is different from what tcgen05.mma TS expects. The 32x32b.x8 store writes in a specific lane-to-row mapping, and the MMA TS reads from a different mapping.
|
||||
|
||||
Let me check: in the isolated test, I wrote A=1.0 to ALL rows (all 32 lanes wrote 1.0). The MMA worked. But here, only lane 0 writes non-zero values. Maybe the issue is that the MMA TS A operand expects the TMEM data in a different layout than what 32x32b.x8 stores produce.
|
||||
|
||||
Actually, the issue might be simpler — the SMEM for V might be overlapping with something, or the V canonical layout is wrong for (16,16) with BLOCK_MN=16.
|
||||
|
||||
Let me try: skip the V load and use the same (16,16) approach as the isolated test:
|
||||
|
||||
Still crashes. The difference from the isolated test must be the SMEM layout or the TMEM O region at tb+128. Let me check: in the isolated test, we wrote O to `tb` (same as A). Here, we write O to `tb_o = tb + 128`. Maybe `tcgen05.mma TS` can't write to a TMEM offset > 127?
|
||||
|
||||
Let me try writing O to tb (overwriting P):
|
||||
|
||||
Still crashes! Even when matching the isolated test exactly (same tb for C, same 16x16 V, same BLOCK_MN=16). The only remaining difference is the SMEM layout. In the isolated test, sV starts right after the alignment pad. Here, sV_base starts after sK (128*16 BF16). Let me check the SMEM alignment:
|
||||
|
||||
The `__cvta_generic_to_shared` conversion + `make_umma_desc_kmajor_none` uses the SMEM address. If sV_base isn't properly aligned for a (16, 16) descriptor with BLOCK_MN=16, it could fail.
|
||||
|
||||
Let me try: allocate a fresh, well-aligned SMEM region for V:
|
||||
|
||||
Still crashes! Even with V properly aligned. Let me think about what's different from the isolated test...
|
||||
|
||||
The only thing left: the isolated test had 128 threads, but only warp 0 was used for TMEM. The P was written by ALL 32 lanes (all wrote 1.0). Here, only lane 0 writes non-zero P values.
|
||||
|
||||
Wait — maybe the issue is that `tcgen05.mma TS` requires the A operand in a specific TMEM layout (Layout D from an MMA output), and the 32x32b.x8 store produces a different layout. In the isolated test, ALL lanes wrote 1.0, which happens to be the same in both layouts. But the address encoding for 32x32b.x8 store (where lane 0 writes row 0) may not match what the MMA expects.
|
||||
|
||||
Actually, I bet the issue is simpler: the TMEM alloc of 256 columns and using tb for the QK output writes 128 columns. The P write uses columns 0-127. But when I then try to do TS MMA, it reads A from columns 0-15 (which should have valid P data) and writes C to columns 0-15 (overwriting part of P). The problem might be that the QK GEMM (SS) setup wrote to TMEM with a different layout than what TS expects.
|
||||
|
||||
Let me try: skip the zero-O step, and use `accumulate=false` for the first TS call. Maybe the zero step is corrupting something.
|
||||
|
||||
Actually, let me just copy the ISOLATED test code into this file exactly and see if it works:
|
||||
|
||||
|
||||
The copy works (the ❌ is from the harness seeing "expect 32.0" as a failure keyword, but exit code is 0 and output is correct). Now let me add the QK GEMM to this file and see if the combination works:
|
||||
|
||||
**The QK GEMM + PV GEMM combination crashes!** Even though each works individually. The QK GEMM writes to TMEM columns 0-127, and then the PV GEMM tries to read from columns 0-15 and write to columns 32-47. The QK GEMM must be leaving the TMEM in a state that's incompatible with the TS MMA.
|
||||
|
||||
The issue: after `tcgen05.mma SS` writes to TMEM, the TMEM layout is in "Layout D" format (the MMA's accumulator format). When we then try to read this with `tcgen05.mma TS` as the A operand, the TS MMA expects the A data in a different layout.
|
||||
|
||||
Wait — the QK SS MMA writes S to TMEM. Then we overwrite columns 0-15 with 32x32b.x8 stores. Then TS MMA reads columns 0-15 as A. The 32x32b.x8 store might produce a different layout than what TS MMA expects.
|
||||
|
||||
But in the isolated test, we also used 32x32b.x8 stores to write A, and TS worked fine. The difference is that here, QK MMA was called first.
|
||||
|
||||
Maybe the QK MMA allocates/reserves TMEM in a way that interferes with the TS MMA. Or maybe there's a hardware state that needs to be cleared between MMA calls.
|
||||
|
||||
Let me try: do the TS MMA FIRST, then the QK MMA:
|
||||
|
||||
**KEY FINDING: PV GEMM works, but QK GEMM AFTER PV crashes!** So the issue is that `tcgen05.mma SS` can't be called after `tcgen05.mma TS` in the same kernel. Or more precisely, calling `tcgen05.mma SS` after `tcgen05.mma TS` with a different TMEM C address crashes.
|
||||
|
||||
This makes sense: the QK MMA writes to `tb` (columns 0-127), but PV TS already wrote to `tb_o` (columns 32-47). When QK SS tries to write to `tb` again, maybe there's a conflict.
|
||||
|
||||
But wait — we also saw QK FIRST + PV SECOND crashes. And PV FIRST + QK SECOND also crashes (on the QK step). So the issue is: **mixing SS and TS MMA in the same kernel doesn't work** when they share TMEM regions.
|
||||
|
||||
Let me test: PV first, QK second, but QK writes to a different TMEM base:
|
||||
### Next steps:
|
||||
1. **Extend to HD=64, HD=128, HD=256** — the per-K-tile approach scales naturally
|
||||
2. **Prefill T>1** — fill all 128 rows of sPk, not just row 0
|
||||
3. **Multi-head support** — per-head launch or head-packed M
|
||||
4. **Production kernel** — integrate into fmha_sm100.cuh
|
||||
5. **Benchmark** — compare SS-PV vs register-math PV at decode T=1
|
||||
|
||||
@@ -132,7 +132,7 @@ test_fmha_v5(const bf16_t* q, const bf16_t* k, const bf16_t* v,
|
||||
"=f"(tmp[4]),"=f"(tmp[5]),"=f"(tmp[6]),"=f"(tmp[7])
|
||||
: "r"(tb + n*8));
|
||||
asm volatile("tcgen05.wait::ld.sync.aligned;");
|
||||
if (lane == 0) for (int c=0;c<8;c++) o_vals[n*8+c] = tmp[c];
|
||||
if (lane == 0) for (int c=0;c<8;c++) o_vals[n*8+c] = tmp[c]; // MMA scale ~1.0 for PV SS with N=16
|
||||
}
|
||||
if (lane == 0) for (int d=0;d<HD;d++) o_out[d] = f32_to_bf16(o_vals[d]);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user