shit carmine left dangling

This commit is contained in:
2026-05-23 06:55:22 +00:00
parent fe81eba7aa
commit 4eccbb05c1
2 changed files with 30 additions and 7 deletions

View File

@@ -365,15 +365,24 @@ This is the single most important structural change. Once the kernel can output
- [ ] Sink merge in TMEM: `O = (exp(lse1) * O1 + exp(sink) * exp(lse2) * O2) / (exp(lse1) + exp(sink) * exp(lse2))`
- [ **Test:** output matches D5b Python merge
### D6 — Mixed-Precision KV Load Path (CG-5)
### D6 — FP4 KV Load Path with On-the-Fly Dequant (MERGED INTO D1)
- [ ] TMA loads FP8 NoPE dims to SMEM slot 0
**Why D6 is no longer a separate stage:** Designing SMEM-P around BF16 KV and then retrofitting FP4 is the detour trap. FP4 KV at hd=512 shrinks each KV tile 4× vs BF16, which changes the fundamental pipeline depth (kv_stage 2→4-6) and SMEM budget. The kernel we ship will run with FP4 KV — so plan for that architecture now.
**Paper §2.3.4:** KV cache stores dims 0..447 as FP8 and dims 448..511 as BF16. The paged cache already implements this split (`entries_fp8` + `entries_rope` + `inv_scale`). For FMHA, we take it further: TMA loads FP4 (or FP8) KV to SMEM, dequantize on-the-fly in the SMEM→register path, then MMA.
**FP4 KV pipeline depth win:** At BF16 hd=512, one K tile = 128 × 512 × 2 = 128 KB. 2 stages = 512 KB (K+V). At FP4 (with FP8 scale overhead): ~36 KB per K tile, same SMEM supports 6+ stages. Each extra stage hides more TMA latency. At 1M-context decode, deeper stages matter a lot.
**Implementation:**
- [ ] TMA loads FP4 NoPE dims (packed e2m1_x2) to SMEM slot 0
- [ ] TMA loads BF16 RoPE dims to SMEM slot 1
- [ ] Dequantize FP8 → BF16 in SMEM (vectorized `* inv_scale`)
- [ ] Concatenate [NoPE, RoPE] in SMEM
- [ ] TMA loads FP8 scale factors to SMEM slot 2
- [ ] Dequantize FP4→BF16 in SMEM (vectorized `* FP8_scale * global_scale`, 16-element microblocks)
- [ ] Concatenate [NoPE, RoPE] in SMEM (or use separate MMA operands)
- [ ] MMA reads contiguous BF16 from SMEM
- [ ] **Test:** FP8+BF16 split input matches pure BF16 input (dequant is transparent)
- [ ] **Prerequisite:** D1 (SMEM-P) and D5 (sink merge) working first
- [ ] Verify TMA uses `float4_e2m1fn_x2` element type for FP4 (not uint8)
- [ ] **Test:** FP4+BF16 split input matches pure BF16 input (dequant is transparent)
- [ ] **Prerequisite:** D1.3 (SMEM-P) working at BF16 first for correctness, then add FP4
---

View File

@@ -242,7 +242,7 @@ class FmhaKernel:
tScS = qk_thr.partition_C(cS)
tTMEM_LOADcS = thr_load.partition_D(tScS)
# P store atoms (always defined for CuTeDSL scoping; only used when use_smem_p=False)
# P store atoms: TMEM-P (always defined, only used when use_smem_p=False)
p_cols_fp32 = self.pv_mma_tiler[2] * self.q_dtype.width // self.qk_acc_dtype.width
tStP_layout = cute.composition(tStS.layout, cute.make_layout((self.pv_mma_tiler[0], p_cols_fp32)))
# Use 0 as P offset when SMEM-P (these atoms are never used, but must be valid)
@@ -255,6 +255,20 @@ class FmhaKernel:
tScP = cute.make_tensor(tScS.iterator, tScP_layout)
tTMEM_STOREcP = thr_store.partition_S(tScP)
# P SMEM copy atoms: SMEM-P (always defined, only used when use_smem_p=True)
# Uses make_tiled_copy_C to partition threads by QK MMA's C-fragment layout.
# Softmax warps have P values in QK C-fragment layout (same as rP_bf16).
# This copy writes those values to sP which has PV A-operand SMEM layout.
smem_copy_atom = cute.make_copy_atom(
cute.nvgpu.CopyUniversalOp(),
self.q_dtype,
num_bits_per_copy=128,
)
tiled_smem_copy = cute.make_tiled_copy_C(smem_copy_atom, qk_mma)
thr_smem_copy = tiled_smem_copy.get_slice(sfw_idx)
sP_2d = cute.group_modes(sP, 0, 3) # flatten to 2D for copy
tSMEM_CPYsP = thr_smem_copy.partition_D(sP_2d) # destination (SMEM)
row_max = -Float32.inf
row_sum = Float32(0.0)
scale_log2 = Float32(self.scale_softmax_log2)