# DeepSeek-V4 NVFP4 Kernel Suite

CuTeDSL kernels for DeepSeek-V4 (Blackwell B200, SM100). All kernels use `cutlass.cute` (CuTeDSL) with Blackwell tensor cores.

## Status (May 21, 2026 — 05:15 UTC)

### ✅ Stage A: Bare Q@K^T via tcgen05.mma → TMEM → GMEM — COMPLETE

- **File**: `tests/test_stage_a_v2.py`
- **Result**: Q(128,128) @ K^T(128,128) → S(128,128), cosine 0.999999

### 🔨 Stage B: Two MMAs + Identity Softmax — IN PROGRESS

- **Pipeline deadlock: FIXED. Kernel runs without deadlock.**
- **Bug 1 (V MN-major): ✅ Fix applied.**
- **Bug 2 (softmax packing): ✅ Confirmed correct (V=I test: cosine 1.0).**
- **Bug 3 (ACCUMULATE): ✅ Fix applied.**
- **Bug 4 (non-square PV): 🔨 ROOT CAUSE IDENTIFIED — TMEM layout mismatch.**
#### Bug 4 (CURRENT): PV MMA Broken for (128,64) Output — ROOT CAUSE IDENTIFIED

**Root Cause: The (128,64) PV MMA's A-fragment reads P from TMEM with a different layout than the softmax packing writes it.**

The softmax packing writes P using the **QK C-fragment layout** (MMA atom = (128,128,16), N_MMA=128). The PV MMA reads P using its **A-fragment layout** (MMA atom = (128,64,16), N_MMA=64). These two layouts produce different physical TMEM addresses for the same logical (m,k) coordinate.

**Evidence:**

- Truncated identity V (64,128) MN-major: O[m,d] ≈ P[m, 2d] — the MMA reads every other column of P
- All-ones V: cosine 0.999999 ✅ (uniform data hides the layout mismatch)
- Single-element V: cosine 1.0 ✅ (sparse data also hides it)
- (128,128) PV with same softmax packing: cosine 0.999999 ✅ (N_MMA=128 matches QK, no mismatch)

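
The evidence pattern above can be reproduced with a toy model in plain Python. This is only a sketch: it does not model TMEM addressing at all. It assumes, hypothetically, that the mis-read acts like a column permutation of P (even columns first, then odd columns), which is one mapping consistent with O[m,d] ≈ P[m, 2d]. The sizes, the `perm` mapping, and all helper names are illustrative, and V is written in the mathematical (K, D) orientation rather than the kernel's (head_dim, seq) storage.

```python
import math

def matmul(a, b):
    """Plain nested-list matrix product: (M,K) @ (K,D) -> (M,D)."""
    return [[sum(a[m][k] * b[k][d] for k in range(len(b)))
             for d in range(len(b[0]))] for m in range(len(a))]

def cosine(x, y):
    """Cosine similarity of two matrices, flattened."""
    fx = [v for row in x for v in row]
    fy = [v for row in y for v in row]
    dot = sum(p * q for p, q in zip(fx, fy))
    return dot / (math.sqrt(sum(p * p for p in fx)) * math.sqrt(sum(q * q for q in fy)))

M, K, D = 8, 8, 4  # scaled-down stand-ins for 128, 128, 64

# Deterministic, non-uniform P so that a column mix-up is visible.
P = [[1.0 + ((3 * m + 5 * k) % 7) for k in range(K)] for m in range(M)]

# HYPOTHETICAL mis-read: even columns first, then odd columns.
perm = [2 * k for k in range(K // 2)] + [2 * k + 1 for k in range(K // 2)]
P_read = [[row[perm[k]] for k in range(K)] for row in P]

V_trunc_id = [[1.0 if k == d else 0.0 for d in range(D)] for k in range(K)]
V_ones = [[1.0] * D for _ in range(K)]
V_single = [[1.0 if (k, d) == (0, 0) else 0.0 for d in range(D)] for k in range(K)]

for name, V in [("trunc identity", V_trunc_id),
                ("all ones", V_ones),
                ("single element", V_single)]:
    ref, bad = matmul(P, V), matmul(P_read, V)
    print(f"{name:15s} cosine = {cosine(ref, bad):.6f}")
```

Under this assumption the truncated identity exposes the mix-up (the output is exactly P[m, 2d]), while all-ones V cannot see any column permutation because it only sums each row, and a single non-zero element at a position the permutation leaves fixed also matches the reference.
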

**C++ TMEM Fragment Layout (from mma_traits_sm100.hpp):**

```cpp
// For M_MMA = 128, N_MMA varies with the MMA atom's N dimension
Layout tmem_atom = Layout<Shape <_128, Int<N_MMA>>,
                          Stride<  _1,      _128>>{};
```

- QK C-fragment: N_MMA=128 → 128 TMEM columns, stride 128
- PV A-fragment (128,64): N_MMA=64 → 64 TMEM columns, stride 128

When the softmax packing writes P at `tmem_p0_offset` using the QK C-fragment layout (N_MMA=128), element P[m,k] lands at TMEM address `m + 128*k`. The PV A-fragment (N_MMA=64) reads the same TMEM region as if P had been stored for N_MMA=64, so it effectively interprets the data with stride 64 instead of 128, even though the atom layout above lists a nominal column stride of 128 for both fragments. The result is the every-other-column effect (O[m,d] ≈ P[m, 2d]).

**Fix (not yet applied): The softmax packing must write P using the PV MMA's A-fragment layout, not the QK C-fragment layout.** FMHA does this correctly because its softmax writes P using a composition that matches the PV A-fragment: the `tStS_P` layout is derived from `tStS.layout` (QK C-fragment), but the TMEM store uses a C-fragment composition based on the PV MMA's tiling. The key is that FMHA's `tilePlikeFP32` computation adapts the packing width to match the PV output N.

**Additional fix: `epi_tile` must be computed from PV cta_tile, not QK cta_tile.** Using QK's cta_tile for the epilogue produces `epi_tile=(128,128)`, which is wrong for a (128,64) output. Computing from PV's cta_tile gives `epi_tile=(128:1, 32:1)`. This fix alone improved cosine from 0.01 to 0.876, but the TMEM layout mismatch remains.

**V SMEM Layouts (confirmed correct):**

- `PV(128,64) V SMEM: outer=((64,16),1,8,1):((1,64),0,1024,0), inner=S<3,4,3>`
- `PV(128,128) V SMEM: outer=(((64,2),16),1,8,1):(((1,8192),64),0,1024,0), inner=S<3,4,3>`

---
### Bug 1: V B-Operand Must Be MN-Major — ✅ FIX APPLIED

V must be shaped (head_dim, seq) = (64, 128) with strides (1, 64), i.e. MN-major.
The PV MMA uses `v_major` (`OperandMajorMode.MN`) instead of `b_major` (K).

V must be built with `as_strided`: a default (contiguous) PyTorch tensor of shape (64,128) has strides (128,1), which is K-major.
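
The stride rule above can be checked without a GPU. The following is a minimal sketch in plain Python; `contiguous_strides` and `major_mode` are hypothetical helpers written for this illustration, not CuTeDSL or PyTorch APIs. They compute row-major strides the way a default allocation does and classify which mode of a (MN, K) B-operand is unit-stride.

```python
def contiguous_strides(shape):
    """Row-major (C-contiguous) strides in elements, as a default allocation assigns them."""
    strides, step = [], 1
    for extent in reversed(shape):
        strides.append(step)
        step *= extent
    return tuple(reversed(strides))

def major_mode(strides):
    """Classify a (MN, K) operand: 'MN' if the MN mode is unit-stride, 'K' if the K mode is."""
    if strides[0] == 1:
        return "MN"
    if strides[1] == 1:
        return "K"
    raise ValueError("neither mode is contiguous")

head_dim, seq = 64, 128

# A freshly allocated (head_dim, seq) tensor: seq (the K mode of P@V) is contiguous.
default = contiguous_strides((head_dim, seq))

# Transposed view of a contiguous (seq, head_dim) tensor: same shape, swapped strides.
seq_hd = contiguous_strides((seq, head_dim))
mn_view = (seq_hd[1], seq_hd[0])

print(default, major_mode(default))
print(mn_view, major_mode(mn_view))
```

In PyTorch terms, the second case corresponds to taking `.t()` of a contiguous (seq, head_dim) tensor, which yields shape (64, 128) with strides (1, 64), the same view that `as_strided` is used to build here.
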
### Bug 2: C-Fragment Composition Store — ✅ CONFIRMED CORRECT

FP32→BF16 packing via C-fragment composition store (FMHA pattern) is correct.
Proven by V=I test (cosine 1.0) and random V 128x128 test (cosine 0.999999).

⛔ **FOOTGUN**: `St32x32bOp` MUST use Float32, NOT BFloat16.

⚠️ The recast view for P packing uses the LOAD layout (128 BF16 elements), not the store composition shape.

### Bug 3: First PV Must Use ACCUMULATE=False — ✅ FIX APPLIED

If ACCUMULATE=True on the first PV, `O = P@V + old_O` adds uninitialized TMEM. Always ACCUMULATE=False for first PV, then True for subsequent tiles.

---
## 🔨 Stage C: Online Softmax — AFTER B

Per the pseudocode: epilogue warps compute per-row tile_max, rescale, exp, store P back to TMEM.

## 🔨 Stage D: FP8 Paged KV Gather — AFTER C

Replace BF16 TMA load with FP8 paged KV gather + per-position dequant.

---
## Pipeline Deadlock — ✅ FIXED (May 21)

v20-v25 all deadlocked on GPU. Three root causes were found and fixed:

### Fix 1: PipelineUmmaAsync for mma_si Must NOT Pass cta_layout_vmnk

FMHA's mma_s0/mma_s1 PipelineUmmaAsync calls do NOT pass `cta_layout_vmnk`. Removing it fixes the deadlock.

### Fix 2: TMA Warp Must NOT Call tmem.wait_for_alloc()

The tmem allocation barrier has `num_threads = 32 * (mma_warp + epilogue_warps)`. The TMA warp is NOT part of this barrier. Calling `wait_for_alloc()` from the TMA warp corrupts the barrier.
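
Why an extra participant is harmful can be pictured with a toy arrival-count barrier. This is a sketch in plain Python, not the CuTeDSL barrier: the `ArrivalBarrier` class, the warp counts (1 MMA warp, 4 epilogue warps), and the arrival order are all assumptions made for illustration, and the document does not state exactly how the corruption manifests on hardware.

```python
class ArrivalBarrier:
    """Toy arrival-count barrier: releases once `expected` threads have arrived."""

    def __init__(self, expected):
        self.expected = expected
        self.arrived = 0
        self.released_by = None  # which arrival tipped the count

    def arrive(self, who, n_threads):
        self.arrived += n_threads
        if self.released_by is None and self.arrived >= self.expected:
            self.released_by = who

WARP = 32
mma_warps, epilogue_warps = 1, 4  # assumed warp counts
expected = WARP * (mma_warps + epilogue_warps)

def simulate(tma_warp_participates):
    bar = ArrivalBarrier(expected)
    bar.arrive("epilogue", WARP * epilogue_warps)
    if tma_warp_participates:
        bar.arrive("tma", WARP)  # the warp that is not counted in `expected`
    bar.arrive("mma", WARP * mma_warps)
    return bar.released_by, bar.arrived

print(simulate(False))
print(simulate(True))
```

In the second run the barrier is released by the TMA warp before the MMA warp has arrived, and the total arrival count overshoots `expected`. Either symptom would be consistent with the barrier being left in an inconsistent state.
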
### Fix 3: PipelineTmaStore (not TmaStorePipeline)

`pipeline.TmaStorePipeline` does not exist. The correct name is `pipeline.PipelineTmaStore`.

---

## ⛔ DEAD TEST: test_stage_b_v21.py — DELETED, DO NOT RECREATE

v21 attempted both Bug 1 and Bug 2 fixes in a hand-rolled pipeline kernel. It deadlocks on GPU. Root cause: pipeline synchronization mismatch. **Do not recreate.** Write from scratch using fmha.py as the reference.

---
## ⛔ FOOTGUNS — CUTLASS CuTeDSL Landmines

### 1. St32x32bOp with 16-bit dtype → ILLEGAL MEMORY ACCESS

`St32x32bOp(Repetition(N), BFloat16)` crashes at runtime. You MUST use `St32x32bOp(Repetition(N), Float32)` and pack 2×16-bit values into 1×Float32 backing words via `cute.recast_ptr`. The 16-bit type only appears in the recast view, never in the store atom itself.
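
The "two 16-bit values per 32-bit backing word" idea can be illustrated on the host with the standard library. This is a sketch only: it assumes a little-endian recast, so that element 0 of each pair sits in the low half of the word, and the helper names are hypothetical. It says nothing about which TMEM column each word lands in.

```python
import struct

def f32_to_bf16_bits(x):
    """Round-to-nearest-even conversion of a float to BF16 bits (NaN handling omitted)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)
    return (bits >> 16) & 0xFFFF

def bf16_bits_to_f32(h):
    """Widen BF16 bits back to a Python float."""
    return struct.unpack("<f", struct.pack("<I", h << 16))[0]

def pack_pair(lo, hi):
    """Two BF16 values in one 32-bit word; element 0 in the low half (assumed order)."""
    return f32_to_bf16_bits(lo) | (f32_to_bf16_bits(hi) << 16)

def unpack_pair(word):
    return bf16_bits_to_f32(word & 0xFFFF), bf16_bits_to_f32(word >> 16)

# Four values that are exactly representable in BF16.
row = [0.5, -1.25, 3.0, 0.15625]
words = [pack_pair(row[i], row[i + 1]) for i in range(0, len(row), 2)]

for w in words:
    print(hex(w), unpack_pair(w))
```

Four 16-bit elements become two 32-bit words, which is why the store atom is declared with a 32-bit type while the 16-bit type only shows up in the recast view.
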
### 2. V B-Operand Major Mode ≠ K Major Mode

FMHA requires `v_major_mode == OperandMajorMode.MN`. Passing K's K-major mode for V is WRONG. V must be shaped (head_dim, seq) with strides (1, head_dim) to produce MN-major. Standard PyTorch row-major (seq, head_dim) gives K-major.

### 3. CuTe Nested Layout Modes Flatten Sequentially

A layout like `((128,16),1,(4,2)):((65536,1),0,(16,64))` looks "non-sequential" but flattens to `addr = m*65536 + k` with k = k0 + 16*k1 + 64*k2 (CuTe's native coordinate order, in which the leftmost sub-mode varies fastest). Do NOT assume nested modes imply non-sequential physical addressing. The C-fragment composition and A-fragment alias the same TMEM columns, BUT ONLY WHEN N_MMA MATCHES (i.e., (128,128) PV). For (128,64) PV, N_MMA=64 and the alias breaks.
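
The flattening claim can be verified by evaluating the layout by hand. The sketch below re-implements only the coordinate-to-index inner product in plain Python (`crd2idx` here is a local helper written for this example, not an import from CuTe) and checks that, for a fixed m, the 128 k-coordinates map to 128 consecutive offsets.

```python
from itertools import product

def crd2idx(coord, stride):
    """Inner product of a (possibly nested) coordinate with a congruent stride."""
    if isinstance(coord, tuple):
        return sum(crd2idx(c, s) for c, s in zip(coord, stride))
    return coord * stride

# Layout ((128,16),1,(4,2)):((65536,1),0,(16,64)); only the strides are needed here.
stride = ((65536, 1), 0, (16, 64))

def column_offsets(m):
    """Offsets relative to m*65536, visiting k0 fastest, then k1, then k2."""
    offsets = []
    for k2, k1, k0 in product(range(2), range(4), range(16)):
        addr = crd2idx(((m, k0), 0, (k1, k2)), stride)
        offsets.append(addr - m * 65536)
    return offsets

for m in (0, 1, 127):
    assert column_offsets(m) == list(range(128))
print("k = k0 + 16*k1 + 64*k2 is contiguous for every tested m")
```
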
### 4. PipelineUmmaAsync Consumer Group = Thread Count, NOT Warp Count

```python
# WRONG: consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 4)
# CORRECT: consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(warp_ids))
```

### 5. PipelineUmmaAsync for mma_si Must NOT Pass cta_layout_vmnk

Passing `cta_layout_vmnk` to the mma_si PipelineUmmaAsync causes deadlock. FMHA does not pass it. Remove it.

### 6. TMA Warp Must NOT Call tmem.wait_for_alloc()

The tmem allocation barrier only includes MMA + epilogue warps. The TMA warp is excluded. Calling `wait_for_alloc()` from the TMA warp corrupts the barrier.

### 7. PV MMA ACCUMULATE Must Be False on First Tile

If ACCUMULATE=True on the first PV MMA, `O = P@V + old_O` adds uninitialized TMEM to the result. Always set ACCUMULATE=False for the first PV, then True for subsequent tiles. FMHA: `pv_tiled_mma.set(tcgen05.Field.ACCUMULATE, kphase_idx != 0)`.
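
The effect of the flag is easy to see in a host-side toy. The sketch below is plain Python, not the tcgen05 API: `mma` is a hypothetical stand-in for `D = A@B (+ D)`, and the value `1e6` stands in for whatever uninitialized TMEM happens to contain.

```python
def mma(acc, p_tile, v_tile, accumulate):
    """Toy D = A@B (+ D); `acc` is mutated in place like a TMEM accumulator."""
    rows, inner, cols = len(p_tile), len(v_tile), len(v_tile[0])
    for m in range(rows):
        for d in range(cols):
            prod = sum(p_tile[m][k] * v_tile[k][d] for k in range(inner))
            acc[m][d] = prod + (acc[m][d] if accumulate else 0.0)

def run(first_accumulate):
    garbage = 1e6  # stands in for uninitialized TMEM
    acc = [[garbage, garbage] for _ in range(2)]
    p_tiles = [[[1.0, 2.0], [3.0, 4.0]], [[0.5, 0.5], [1.0, 0.0]]]
    v_tiles = [[[1.0, 0.0], [0.0, 1.0]], [[2.0, 2.0], [2.0, 2.0]]]
    for idx, (p, v) in enumerate(zip(p_tiles, v_tiles)):
        # Mirrors the `kphase_idx != 0` pattern unless the first tile is forced to accumulate.
        mma(acc, p, v, accumulate=(idx != 0) or first_accumulate)
    return acc

print(run(first_accumulate=False))  # garbage is overwritten by the first tile
print(run(first_accumulate=True))   # garbage leaks into every output element
```
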
### 8. TMEM Pointer Arithmetic: Offset Units Depend on Pointer Type

When computing PV A-fragment offset from the softmax P offset:

```python
# Softmax store: FP32 pointer + tmem_p0_offset (in FP32 elements)
tStS_P = cute.make_tensor(tStS.iterator + tmem_p0_offset, tStS_P_layout)

# PV A-fragment: BF16 pointer + scaled offset (in BF16 elements)
p_offset = acc_dtype.width // q_dtype.width * tmem_p0_offset  # 2 * 32 = 64
tOrP0 = cute.make_tensor(tOrP.iterator + p_offset, tOrP.layout)
```

Both must address the same physical TMEM column. The 2× scaling accounts for FP32→BF16 element size difference.
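
The unit conversion can be sanity-checked as plain arithmetic. This is a small sketch with a hypothetical helper, `scaled_offset`; the value `tmem_p0_offset = 32` is taken from the `2 * 32 = 64` comment in the snippet above and may differ in other configurations.

```python
def scaled_offset(offset_elems, src_bits, dst_bits):
    """Re-express an offset counted in src_bits-wide elements in dst_bits-wide elements."""
    if src_bits % dst_bits != 0:
        raise ValueError("source width must be a multiple of destination width")
    return src_bits // dst_bits * offset_elems

tmem_p0_offset = 32  # FP32 elements (assumed from the comment above)
p_offset = scaled_offset(tmem_p0_offset, src_bits=32, dst_bits=16)

# Both offsets span the same number of bits, hence the same physical position.
assert tmem_p0_offset * 32 == p_offset * 16
print(p_offset)
```

Forgetting the scaling would leave the BF16 pointer at half the intended distance, so the A-fragment would start reading in the middle of the preceding region.
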
### 9. C-Fragment → A-Fragment TMEM Alias Only Works When N_MMA Matches

The softmax packing writes P using the QK C-fragment layout. The PV A-fragment reads P. These alias correctly ONLY when both MMA atoms have the same N_MMA (i.e., both (128,128,16) → N_MMA=128). When the PV MMA uses (128,64,16) → N_MMA=64, the A-fragment has a different TMEM stride and reads garbage. **The softmax packing must be adapted to write P in the PV A-fragment's layout.**

### 10. epi_tile Must Match PV Output Shape, Not QK

`compute_epilogue_tile_shape` must use PV's `cta_tile_shape_mnk`, not QK's. Also, `self.cta_tile_shape_mnk` must be set to PV's cta tile before calling `epilogue_tma_store` (it reads `gemm_kernel.cta_tile_shape_mnk` internally). FMHA sets `self.epi_tile = self.pv_mma_tiler[:2]` directly.

---
## Architecture: Per-Tile Flow

```
For each KV tile:
  1. Load warp writes sKV[stage] (paged FP8 gather via indexed cp.async)
  2. MMA warp issues MMA1: sQ @ sKV[stage]^T → tmem_scores (accumulate=False)
     Signals scores_full_mbar (via PipelineUmmaAsync commit)
  3. Epilogue warps wait on mma_si consumer (scores ready), then:
     a. tcgen05.ld scores from TMEM → register fragments
     b. Compute tile_max, new_max, rescale = exp(old_max - new_max)
     c. Apply rescale to tmem_output IN PLACE (tmem_output *= rescale)
     d. tcgen05.st exp(scores - new_max) back to TMEM → P operand (via C-fragment composition)
     e. Release mma_si (softmax_done — MMA warp can re-acquire and issue PV MMA)
  4. MMA warp waits on mma_si acquire (softmax done), MMA2: P @ sV → tmem_output
     (accumulate=False on the first tile, True on later tiles)
  5. Stage released, load warp can refill it

After all tiles: epilogue warps tcgen05.ld tmem_output, divide by row_sum, cast to BF16, store to GMEM
```
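
The numerics of steps 3b to 3d and the final normalization can be sketched on the host for a single query row. This is a plain-Python reference, not kernel code: the function names and the data layout (one list of logits and one list of V rows per tile) are assumptions made for the illustration. The running maximum starts at negative infinity, so the rescale factor on the first tile is zero and the initial accumulator contents drop out.

```python
import math

def online_softmax_attention(scores_tiles, v_tiles):
    """One query row; scores_tiles[t][k] are logits, v_tiles[t][k] are V rows of length D."""
    dim = len(v_tiles[0][0])
    running_max, row_sum = -math.inf, 0.0
    out = [0.0] * dim
    for scores, v_rows in zip(scores_tiles, v_tiles):
        new_max = max(running_max, max(scores))        # step 3b
        rescale = math.exp(running_max - new_max)      # 0.0 on the first tile
        out = [o * rescale for o in out]               # step 3c
        row_sum *= rescale
        p = [math.exp(s - new_max) for s in scores]    # step 3d
        row_sum += sum(p)
        for pk, vk in zip(p, v_rows):                  # step 4: P @ V, accumulated
            out = [o + pk * x for o, x in zip(out, vk)]
        running_max = new_max
    return [o / row_sum for o in out]                  # final divide by row_sum

def reference_attention(scores, v_rows):
    """Ordinary softmax(scores) @ V over the whole row at once."""
    top = max(scores)
    w = [math.exp(s - top) for s in scores]
    total = sum(w)
    return [sum(wk * vk[d] for wk, vk in zip(w, v_rows)) / total
            for d in range(len(v_rows[0]))]

scores_tiles = [[0.1, 2.0], [3.5, -1.0]]
v_tiles = [[[1.0, 0.0], [0.0, 1.0]], [[2.0, 2.0], [-1.0, 0.5]]]

tiled = online_softmax_attention(scores_tiles, v_tiles)
full = reference_attention(sum(scores_tiles, []), sum(v_tiles, []))
print(max(abs(a - b) for a, b in zip(tiled, full)))
```
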
---

## Test Results

| File | Description | Cosine | Status |
|------|-------------|--------|--------|
| `test_stage_a_v2.py` | Q@K^T only | 0.999999 | ✅ PASS |
| `test_mma_si_only.py` | Q@K^T + mma_si pipeline (no PV) | 0.999999 | ✅ PASS |
| `test_softmax_only.py` | Q@K^T + softmax packing, output S | 0.52 | ❌ S overwritten by P (expected) |
| `test_mma_si_pv.py` | Q@K^T + softmax + P@V (V MN-major, 128x64) | 0.01 | ❌ PV output garbage |
| `test_pv_diag.py` | Q@K^T + softmax + P@V (V=I 128x128) | 1.0 | ✅ PASS |
| `test_pv_diag.py` | Q@K^T + softmax + P@V (random V 128x128) | 0.999999 | ✅ PASS |
| `test_diag_v_truncid.py` | Q@K^T + softmax + P@V (trunc identity 64x128, epi from PV) | 0.02 | ❌ O[m,d]≈P[m,2d] — TMEM alias mismatch |
| `test_diag_v_ones.py` | All-ones V (64x128) | 0.999999 | ✅ uniform data hides mismatch |
| `test_diag_v_ones.py` | Single-element V (64x128) | 1.0 | ✅ sparse data hides mismatch |
| `test_diag_layout.py` | (128,64) PV with epi from PV cta_tile | 0.876 | ❌ partial fix — epi correct, TMEM alias still broken |
| `test_diag_smem_layout.py` | Print V SMEM layouts for (128,64) vs (128,128) | N/A | ℹ️ layouts confirmed correct |
| `test_layout_compare.py` | Print TMEM layouts for QK S and PV A-fragment | N/A | ℹ️ layout inspection |

---
## Critical APIs & Lessons

### TMEM offset arithmetic

- `find_tmem_tensor_col_offset(fragment)` — returns physical TMEM column count
- QK accumulator: 128 TMEM columns
- A-fragment offset: `acc_dtype.width // q_dtype.width * tmem_p0_offset` (F32/BF16=2)

### pv_mma_tiler — FMHA Convention

```python
pv_mma_tiler = (qk_mma_tiler[0], qk_mma_tiler[2], qk_mma_tiler[1])
# = (M, head_dim, QK_N) = (128, 64, 128) for head_dim=64
```

FMHA passes `pv_mma_tiler[:2] = (128, head_dim)` to `make_trivial_tiled_mma`, NOT the QK tiler `(128, 128)`.
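
As a quick check of the convention, the tiler shuffle can be run as ordinary tuple arithmetic. This sketch assumes the QK tiler is ordered (M, N, K) with K = head_dim, as the comment in the snippet above implies; `pv_tiler_from_qk` is a hypothetical helper name.

```python
def pv_tiler_from_qk(qk_mma_tiler):
    """Swap the N and K modes of the QK tiler to obtain the PV tiler."""
    m, n, k = qk_mma_tiler
    return (m, k, n)

qk_mma_tiler = (128, 128, 64)            # (M, QK_N, head_dim) for head_dim=64
pv_mma_tiler = pv_tiler_from_qk(qk_mma_tiler)

pv_tile_mn = pv_mma_tiler[:2]            # what the PV MMA and the epilogue tile should use
qk_tile_mn = qk_mma_tiler[:2]            # the shape that must NOT be reused for PV

print(pv_mma_tiler, pv_tile_mn, qk_tile_mn)
```

The two (M, N) tiles only coincide when head_dim equals the QK N extent, which matches the observation that the (128,128) PV case passes while the (128,64) case does not.
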
### make_trivial_tiled_mma — Use New Overload

```python
make_trivial_tiled_mma(a_dtype, b_dtype, a_leading_mode, b_leading_mode,
                       acc_dtype, cta_group, mma_tiler_mn, a_source=SMEM)
```

### 3D tensors required

Tensors must be 3D (M, K, L) for `cute.local_tile` — add L=1 dimension.

### Other APIs

1. `cutlass_torch.from_dlpack(t).mark_layout_dynamic(leading_dim=...)` — CuTe tensor from PyTorch
2. `PipelineTmaUmma.create(...).make_participants()` — returns (producer, consumer) pair
3. `utils.gemm.sm100.epilogue_tma_store` — handles transform + partition/dcopy. DO NOT hand-roll.
4. `smem.allocate_tensor()` — for SMEM tensors
5. `LayoutEnum.from_tensor(a).mma_major_mode()` — major mode from cute tensor

## Environment

- **Server**: root@45.76.247.107 (B200, 180 GiB HBM3e per GPU)
- **venv**: `source /root/dsv4-nvfp4-workspace/venv/bin/activate`
- **PYTHONPATH**: `/root/dsv4-nvfp4-workspace/kernel`
- **Model**: `/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4`
- **vLLM repo**: `/root/dsv4-nvfp4-workspace/vllm` (modified for Blackwell)
- **Pseudocode**: `/root/fragile-kernel-example/README.md`
- **fmha.py reference**: `/root/cutlass/examples/python/CuTeDSL/cute/blackwell/kernel/attention/fmha/fmha.py`
- **fmha_bwd.py reference**: `/root/cutlass/examples/python/CuTeDSL/cute/blackwell/kernel/attention/fmha/fmha_bwd.py`