diff --git a/STAGE_D.md b/STAGE_D.md index a1542383..8a2d993a 100644 --- a/STAGE_D.md +++ b/STAGE_D.md @@ -365,15 +365,24 @@ This is the single most important structural change. Once the kernel can output - [ ] Sink merge in TMEM: `O = (exp(lse1) * O1 + exp(sink) * exp(lse2) * O2) / (exp(lse1) + exp(sink) * exp(lse2))` - [ **Test:** output matches D5b Python merge -### D6 — Mixed-Precision KV Load Path (CG-5) +### D6 — FP4 KV Load Path with On-the-Fly Dequant (MERGED INTO D1) -- [ ] TMA loads FP8 NoPE dims to SMEM slot 0 +**Why D6 is no longer a separate stage:** Designing SMEM-P around BF16 KV and then retrofitting FP4 is the detour trap. FP4 KV at hd=512 shrinks each KV tile 4× vs BF16, which changes the fundamental pipeline depth (kv_stage 2→4-6) and SMEM budget. The kernel we ship will run with FP4 KV — so plan for that architecture now. + +**Paper §2.3.4:** KV cache stores dims 0..447 as FP8 and dims 448..511 as BF16. The paged cache already implements this split (`entries_fp8` + `entries_rope` + `inv_scale`). For FMHA, we take it further: TMA loads FP4 (or FP8) KV into SMEM, the KV is dequantized on the fly in the SMEM→register path, and the MMA follows. + +**FP4 KV pipeline depth win:** At BF16 hd=512, one K tile = 128 × 512 × 2 B = 128 KB, so 2 stages of K+V = 2 × 2 × 128 KB = 512 KB. At FP4 with FP8 scale overhead, one K tile is ~36 KB (32 KB of packed FP4 + ~4 KB of scales, assuming one FP8 scale per 16 elements), so the same SMEM budget supports 6+ stages. Each extra stage hides more TMA latency. At 1M-context decode, deeper stages matter a lot. 
+ +**Implementation:** +- [ ] TMA loads FP4 NoPE dims (packed e2m1_x2) to SMEM slot 0 - [ ] TMA loads BF16 RoPE dims to SMEM slot 1 -- [ ] Dequantize FP8 → BF16 in SMEM (vectorized `* inv_scale`) -- [ ] Concatenate [NoPE, RoPE] in SMEM +- [ ] TMA loads FP8 scale factors to SMEM slot 2 +- [ ] Dequantize FP4→BF16 in SMEM (vectorized `* FP8_scale * global_scale`, 16-element microblocks) +- [ ] Concatenate [NoPE, RoPE] in SMEM (or use separate MMA operands) - [ ] MMA reads contiguous BF16 from SMEM -- [ ] **Test:** FP8+BF16 split input matches pure BF16 input (dequant is transparent) -- [ ] **Prerequisite:** D1 (SMEM-P) and D5 (sink merge) working first +- [ ] Verify TMA uses `float4_e2m1fn_x2` element type for FP4 (not uint8) +- [ ] **Test:** FP4+BF16 split input matches pure BF16 input (dequant is transparent) +- [ ] **Prerequisite:** D1.3 (SMEM-P) working at BF16 first for correctness, then add FP4 --- diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index 358ba5bb..57faef35 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -242,7 +242,7 @@ class FmhaKernel: tScS = qk_thr.partition_C(cS) tTMEM_LOADcS = thr_load.partition_D(tScS) - # P store atoms (always defined for CuTeDSL scoping; only used when use_smem_p=False) + # P store atoms: TMEM-P (always defined, only used when use_smem_p=False) p_cols_fp32 = self.pv_mma_tiler[2] * self.q_dtype.width // self.qk_acc_dtype.width tStP_layout = cute.composition(tStS.layout, cute.make_layout((self.pv_mma_tiler[0], p_cols_fp32))) # Use 0 as P offset when SMEM-P (these atoms are never used, but must be valid) @@ -255,6 +255,20 @@ class FmhaKernel: tScP = cute.make_tensor(tScS.iterator, tScP_layout) tTMEM_STOREcP = thr_store.partition_S(tScP) + # P SMEM copy atoms: SMEM-P (always defined, only used when use_smem_p=True) + # Uses make_tiled_copy_C to partition threads by QK MMA's C-fragment layout. 
+ # Softmax warps have P values in QK C-fragment layout (same as rP_bf16). + # This copy writes those values to sP which has PV A-operand SMEM layout. + smem_copy_atom = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + self.q_dtype, + num_bits_per_copy=128, + ) + tiled_smem_copy = cute.make_tiled_copy_C(smem_copy_atom, qk_mma) + thr_smem_copy = tiled_smem_copy.get_slice(sfw_idx) + sP_2d = cute.group_modes(sP, 0, 3) # flatten to 2D for copy + tSMEM_CPYsP = thr_smem_copy.partition_D(sP_2d) # destination (SMEM) + row_max = -Float32.inf row_sum = Float32(0.0) scale_log2 = Float32(self.scale_softmax_log2)