shit carmine left dangling
This commit is contained in:
21
STAGE_D.md
21
STAGE_D.md
@@ -365,15 +365,24 @@ This is the single most important structural change. Once the kernel can output
|
||||
- [ ] Sink merge in TMEM: `O = (exp(lse1) * O1 + exp(sink) * exp(lse2) * O2) / (exp(lse1) + exp(sink) * exp(lse2))`
|
||||
- [ **Test:** output matches D5b Python merge
|
||||
|
||||
### D6 — Mixed-Precision KV Load Path (CG-5)
|
||||
### D6 — FP4 KV Load Path with On-the-Fly Dequant (MERGED INTO D1)
|
||||
|
||||
- [ ] TMA loads FP8 NoPE dims to SMEM slot 0
|
||||
**Why D6 is no longer a separate stage:** Designing SMEM-P around BF16 KV and then retrofitting FP4 is the detour trap. FP4 KV at hd=512 shrinks each KV tile 4× vs BF16, which changes the fundamental pipeline depth (kv_stage 2→4-6) and SMEM budget. The kernel we ship will run with FP4 KV — so plan for that architecture now.
|
||||
|
||||
**Paper §2.3.4:** KV cache stores dims 0..447 as FP8 and dims 448..511 as BF16. The paged cache already implements this split (`entries_fp8` + `entries_rope` + `inv_scale`). For FMHA, we take it further: TMA loads FP4 (or FP8) KV to SMEM, dequantize on-the-fly in the SMEM→register path, then MMA.
|
||||
|
||||
**FP4 KV pipeline depth win:** At BF16 hd=512, one K tile = 128 × 512 × 2 = 128 KB. 2 stages = 512 KB (K+V). At FP4 (with FP8 scale overhead): ~36 KB per K tile, same SMEM supports 6+ stages. Each extra stage hides more TMA latency. At 1M-context decode, deeper stages matter a lot.
|
||||
|
||||
**Implementation:**
|
||||
- [ ] TMA loads FP4 NoPE dims (packed e2m1_x2) to SMEM slot 0
|
||||
- [ ] TMA loads BF16 RoPE dims to SMEM slot 1
|
||||
- [ ] Dequantize FP8 → BF16 in SMEM (vectorized `* inv_scale`)
|
||||
- [ ] Concatenate [NoPE, RoPE] in SMEM
|
||||
- [ ] TMA loads FP8 scale factors to SMEM slot 2
|
||||
- [ ] Dequantize FP4→BF16 in SMEM (vectorized `* FP8_scale * global_scale`, 16-element microblocks)
|
||||
- [ ] Concatenate [NoPE, RoPE] in SMEM (or use separate MMA operands)
|
||||
- [ ] MMA reads contiguous BF16 from SMEM
|
||||
- [ ] **Test:** FP8+BF16 split input matches pure BF16 input (dequant is transparent)
|
||||
- [ ] **Prerequisite:** D1 (SMEM-P) and D5 (sink merge) working first
|
||||
- [ ] Verify TMA uses `float4_e2m1fn_x2` element type for FP4 (not uint8)
|
||||
- [ ] **Test:** FP4+BF16 split input matches pure BF16 input (dequant is transparent)
|
||||
- [ ] **Prerequisite:** D1.3 (SMEM-P) working at BF16 first for correctness, then add FP4
|
||||
|
||||
---
|
||||
|
||||
|
||||
@@ -242,7 +242,7 @@ class FmhaKernel:
|
||||
tScS = qk_thr.partition_C(cS)
|
||||
tTMEM_LOADcS = thr_load.partition_D(tScS)
|
||||
|
||||
# P store atoms (always defined for CuTeDSL scoping; only used when use_smem_p=False)
|
||||
# P store atoms: TMEM-P (always defined, only used when use_smem_p=False)
|
||||
p_cols_fp32 = self.pv_mma_tiler[2] * self.q_dtype.width // self.qk_acc_dtype.width
|
||||
tStP_layout = cute.composition(tStS.layout, cute.make_layout((self.pv_mma_tiler[0], p_cols_fp32)))
|
||||
# Use 0 as P offset when SMEM-P (these atoms are never used, but must be valid)
|
||||
@@ -255,6 +255,20 @@ class FmhaKernel:
|
||||
tScP = cute.make_tensor(tScS.iterator, tScP_layout)
|
||||
tTMEM_STOREcP = thr_store.partition_S(tScP)
|
||||
|
||||
# P SMEM copy atoms: SMEM-P (always defined, only used when use_smem_p=True)
|
||||
# Uses make_tiled_copy_C to partition threads by QK MMA's C-fragment layout.
|
||||
# Softmax warps have P values in QK C-fragment layout (same as rP_bf16).
|
||||
# This copy writes those values to sP which has PV A-operand SMEM layout.
|
||||
smem_copy_atom = cute.make_copy_atom(
|
||||
cute.nvgpu.CopyUniversalOp(),
|
||||
self.q_dtype,
|
||||
num_bits_per_copy=128,
|
||||
)
|
||||
tiled_smem_copy = cute.make_tiled_copy_C(smem_copy_atom, qk_mma)
|
||||
thr_smem_copy = tiled_smem_copy.get_slice(sfw_idx)
|
||||
sP_2d = cute.group_modes(sP, 0, 3) # flatten to 2D for copy
|
||||
tSMEM_CPYsP = thr_smem_copy.partition_D(sP_2d) # destination (SMEM)
|
||||
|
||||
row_max = -Float32.inf
|
||||
row_sum = Float32(0.0)
|
||||
scale_log2 = Float32(self.scale_softmax_log2)
|
||||
|
||||
Reference in New Issue
Block a user