NVFP4-1.1: update approach doc and fp4_quant with CuTeDSL API fixes
This commit is contained in:
207
NVFP4_ROADMAP.md
Normal file
207
NVFP4_ROADMAP.md
Normal file
@@ -0,0 +1,207 @@
|
||||
## NVFP4 Precision Roadmap (May 23, 2026)
|
||||
|
||||
Three honest buckets. A fourth speculative bucket flagged at the end.
|
||||
|
||||
### NVFP4-0: Verify Right Blackwell FP4 Primitives ⚡ DO FIRST
|
||||
|
||||
**No correctness or quality risk. Pure correctness of implementation.** If these are wrong, we're running wrong MMA shapes silently.
|
||||
|
||||
#### NVFP4-0.1 — sf_dtype tracing
|
||||
|
||||
**What:** Trace the SF dtype through the full pipeline: `gemm_runner.py` → `dense.py` → `blockscaled_utils` → TMEM layout.
|
||||
|
||||
**The problem:** `dense.py` line 137 says NVF4 supports `Float8E8M0FNU/Float8E4M3FN` at sf_vec_size=16. But UE8M0 is the MXFP4/MXFP8 scale format. NVFP4 uses **FP8 E4M3**. The examples on lines 90/100 show `Float8E8M0FNU` at sf_vec_size=16 which is the MXFP4 path. **Need to verify the runner is passing E4M3, not E8M0.**
|
||||
|
||||
Action:
|
||||
- [ ] Print `sf_dtype` in `gemm_runner.py` at construction: `print(f"sf_dtype={sf_dtype}, sf_vec_size={SF_VEC_SIZE}")`
|
||||
- [ ] Print `self.sf_dtype` in `dense.py` `BlockScaledGEMM.__init__`
|
||||
- [ ] Print `self.sf_vec_size` in `dense.py`
|
||||
- [ ] Trace through `blockscaled_utils.make_sm100_sf_layout` — does it produce E4M3 packing (4 FP8 E4M3 → 1 int32) or UE8M0 packing?
|
||||
- [ ] **If wrong sf_dtype is found:** fix in `gemm_runner.py` SF_DTYPE constant, retest MoE cosine
|
||||
|
||||
#### NVFP4-0.2 — SF TMEM layout verification
|
||||
|
||||
**What:** NVFP4 expects scale factors in TMEM in a specific transposed-packed layout. UE4M3 for NVFP4 (4 packed FP8 E4M3 per int32 word). The comment in `dense.py` about "SM100 requires scaling factors in packed UE8M0 format" is for **MXFP8**, not NVFP4.
|
||||
|
||||
Action:
|
||||
- [ ] Print TMEM scale-factor offsets at GEMM construction: `print(f"sf_smem_layout={sf_smem_layout}, sf_tmem_offset={sf_tmem_offset}")`
|
||||
- [ ] Verify the packing matches UE4M3 (NVFP4) not UE8M0 (MXFP8)
|
||||
- [ ] Trace `blockscaled_utils.make_sm100_sf_layout` and print the output layout
|
||||
- [ ] **If wrong packing:** fix `make_sm100_sf_layout` or add NVFP4-specific layout path
|
||||
|
||||
#### NVFP4-0.3 — FP4 TMA element type
|
||||
|
||||
**What:** `float4_e2m1fn_x2` must survive all the way into TMA descriptor creation. Blackstone TMA supports `e2m1_x2` packed-FP4 element type directly. Loading as `uint8` works but loses tensor-core awareness.
|
||||
|
||||
Action:
|
||||
- [ ] Trace `float4_e2m1fn_x2` through `quantize.py` → TMA atom creation in `fmha.py`
|
||||
- [ ] Print the GMEM tensor dtype at FMHA kernel input
|
||||
- [ ] Print the TMA atom dtype at construction
|
||||
- [ ] Verify `cpasync.tma_partition` receives `float4_e2m1fn_x2` element type, not uint8
|
||||
- [ ] **If uint8 fallback:** fix TMA atom creation in `fmha.py`
|
||||
|
||||
#### NVFP4-0.4 — MMA kind is mxf4nvf4
|
||||
|
||||
**What:** Blackwell has a single MMA kind for both MXFP4 and NVFP4. NVFP4 = scales are FP8 E4M3, 16-element block. MXFP4 = scales are UE8M0, 32-element block. The MMA kind is determined by scale-factor type at runtime. Need to confirm tcgen05 is inferring NVFP4.
|
||||
|
||||
Action:
|
||||
- [ ] Print `tcgen05.mma.kind` at GEMM construction (if accessible)
|
||||
- [ ] Print the MMA instruction shape `(M, N, K)` confirmed by JIT compile
|
||||
- [ ] Verify it matches Blackwell MMA shape for NVFP4 (not MXFP4)
|
||||
|
||||
**Execution:** These are 5-minute print jobs. Do all 4 NVFP4-0 items before touching any code. If any of them reveals a wrong dtype, fix it FIRST before anything else. A wrong sf_dtype poisons every FP4 GEMM result.
|
||||
|
||||
---
|
||||
|
||||
### NVFP4-1: Eliminate BF16 Round-Trips After FP4 GEMMs 🔴 PURE-WIN, NO QUALITY RISK
|
||||
|
||||
**These are pure bandwidth/compute wins. The math doesn't change — we just avoid precision loss and kernel launch overhead.**
|
||||
|
||||
#### NVFP4-1.1 — Fuse FP4 quant into SwiGLU epilogue (MoE L1 → L2)
|
||||
|
||||
**What:** Current MoE forward:
|
||||
```
|
||||
padded_x_fp4 → L1 GEMM → SwiGLU → BF16 GMEM ← LEAK
|
||||
quantize_activation_nvfp4 ← SEPARATE KERNEL
|
||||
padded_activated_fp4 → L2 GEMM → BF16 GMEM
|
||||
```
|
||||
|
||||
**Paper §4.2.2:** NVFP4 weights. L1 → SwiGLU → online amax → FP8 scale + FP4 pack → FP4 GMEM → L2 GEMM.
|
||||
|
||||
The amax reduction is in-registers: for an epi tile with 16 contiguous elements per thread, each tile produces one FP8 E4M3 scale and 64 bits of packed FP4 nibbles. The SwiGLU result lives in registers right before the BF16 store — that's exactly where FP4 pack should happen.
|
||||
|
||||
**What you save:**
|
||||
- ~2× GMEM bandwidth between L1 and L2 (FP4 instead of BF16)
|
||||
- Entire `quantize_activation_nvfp4` kernel launch
|
||||
- `padded_activated_fp4` / `padded_activated_x_sf` scratch buffers
|
||||
- GPU-side amax computation (runs on tensor cores vs scalar)
|
||||
- L2 scale-factor TMA reads FP8 scales L1 just produced
|
||||
|
||||
**How:** Extend the fused SwiGLU epilogue. After computing `gate * up`:
|
||||
1. Compute per-16-element amax across the subtile (all-reduce or butterfly shfl_xor)
|
||||
2. Compute FP8 E4M3 scale = amax / 448 (E4M3 max)
|
||||
3. Pack each element: `sign_bit << 7 | (clamped_val / scale).to(uint4)`
|
||||
4. Write packed nibbles to GMEM as `float4_e2m1fn_x2`
|
||||
5. Write FP8 scale to SF TMA buffer
|
||||
|
||||
**The amax subtlety:** For NVFP4 the microblock is 16 elements. Port the same 16-element logic from `quantize.py` into the epilogue. Do NOT use 32-element MXFP4 microblocks.
|
||||
|
||||
- [ ] Extend `dsv4/kernels/gemm/fused_swiglu.py` epilogue to FP4 pack SwiGLU output
|
||||
- [ ] Add per-subtile amax reduction (register-only, no extra kernel)
|
||||
- [ ] Verify: L1 → L2 cosine matches reference (no regression from BF16 intermediate)
|
||||
- [ ] Verify: L2 GEMM reads FP4 scales produced by L1 epilogue
|
||||
- [ ] **Test:** MoE layer output cosine with full L1→L2 pipeline
|
||||
- [ ] **Scope:** MoE-side, NOT fmha.py. Does not block FMHA D1.
|
||||
|
||||
#### NVFP4-1.2 — Fuse FP4 quant into inverse RoPE → wo_a path
|
||||
|
||||
**What:** `inverse_rope_bf16` produces BF16, then `wo_a` quantizes it. Fuse FP4 quant into inverse RoPE epilogue.
|
||||
|
||||
Same pattern as NVFP4-1.1: after inverse RoPE rotation, compute amax → FP8 scale → FP4 pack → FP4 GMEM. The `wo_a` GEMM reads FP4 + scales.
|
||||
|
||||
- [ ] Extend `dsv4/ops/rope.py` inverse RoPE to emit FP4 instead of BF16
|
||||
- [ ] Wire `wo_a` GEMM to read FP4 scales from inverse RoPE output
|
||||
- [ ] **Test:** attention sub-block output cosine (full inverse RoPE → wo_a → attention)
|
||||
|
||||
#### NVFP4-1.3 — Fuse FP4 quant into mHC mixing → attention/FFN input
|
||||
|
||||
**What:** `B_l @ X_l + C_l ⊗ F_out` (mHC post_block) lands in BF16. Attention's `q_down` and FFN's L1 GEMM quantize it. Fuse quant into mHC mixing kernel.
|
||||
|
||||
Same pattern. After mHC mixing post-compute, amax → FP8 scale → FP4 pack → FP4 GMEM. Attention and FFN GEMMs read FP4.
|
||||
|
||||
- [ ] Add FP4 epilogue to mHC mixing kernel (when building it)
|
||||
- [ ] Wire attention `q_down` and FFN L1 to read mHC FP4 output
|
||||
- [ ] **Test:** end-to-end layer cosine (mHC → attention → FFN)
|
||||
|
||||
**Note:** NVFP4-1.2 and NVFP4-1.3 depend on D1.5 (correction epilogue fix) because those epilogues need the clean one-way TMEM path. NVFP4-1.1 (MoE SwiGLU) is independent.
|
||||
|
||||
---
|
||||
|
||||
### NVFP4-2: FP4 KV Pipeline Depth in FMHA 🔴 STAGE D, DEPENDS ON D1
|
||||
|
||||
**FP4 KV shrinks tiles 4×, same SMEM budget buys 3× more pipeline stages.**
|
||||
|
||||
| KV dtype | Tile size (hd=512) | 2 stages | 4 stages | 6 stages |
|
||||
|-----------|--------------------|----------|----------|----------|
|
||||
| BF16 | 128 KB (K+V) | 512 KB ✅ | — | — |
|
||||
| FP8 | 64 KB (K+V) | 256 KB ✅ | 512 KB | — |
|
||||
| FP4 | ~36 KB (K+V) | 144 KB ✅ | 288 KB | 432 KB |
|
||||
|
||||
Each extra stage hides more TMA latency. At 1M-context decode where KV reads dominate, deeper pipelines are a major perf win.
|
||||
|
||||
**Implementation:**
|
||||
- [ ] After D1 (SMEM-P works with BF16): add FP4 TMA load + SMEM dequant path
|
||||
- [ ] TMA loads FP4 NoPE dims (packed e2m1_x2) to SMEM slot 0
|
||||
- [ ] TMA loads BF16 RoPE dims to SMEM slot 1
|
||||
- [ ] TMA loads FP8 scale factors to SMEM slot 2
|
||||
- [ ] Dequantize FP4→BF16 in SMEM (vectorized `* FP8_scale`, 16-element microblocks)
|
||||
- [ ] Concatenate [NoPE, RoPE] in SMEM
|
||||
- [ ] MMA reads contiguous BF16 from SMEM
|
||||
- [ ] **Prerequisite:** D1 working at BF16 first. Cannot skip.
|
||||
- [ ] **Test:** FP4+BF16 split input → identical output to pure BF16 input (dequant is transparent)
|
||||
|
||||
---
|
||||
|
||||
### NVFP4-3: use_2cta_instrs for Production MoE 🟢 30 MINUTES, PURE PERF
|
||||
|
||||
**This is the single biggest single-knob perf win for FP4 GEMMs on B200.**
|
||||
|
||||
**What:** `FusedSwiGLUScaledGroupedGemmKernel` supports 2-CTA UMMA but defaults to `False`. With 2-CTA, the B operand is TMA-multicast: each CTA reads half of B, peers cross the Infiniband link. Effective MMA tile M doubles (128→256, 256→512).
|
||||
|
||||
**Measured win:** 1.7–1.9× throughput over single-CTA at prefill/batch shapes.
|
||||
|
||||
**Decision tree:**
|
||||
- M < 128 (decode single-token): 1-CTA is correct. 2-CTA wastes hardware.
|
||||
- M ≥ 256 (prefill or batched decode): 2-CTA is free perf.
|
||||
- cluster_m must be even for 2-CTA.
|
||||
|
||||
Action:
|
||||
- [ ] Add conditional: `use_2cta_instrs = (M >= 256 and cluster_m % 2 == 0)`
|
||||
- [ ] Default stays `False` (correct for decode)
|
||||
- [ ] Python GEMM runner sets `use_2cta_instrs=True` for prefill shapes
|
||||
- [ ] **Test:** throughput comparison at M=256, 512, 1024
|
||||
- [ ] **Scope:** MoE-side, `gemm_runner.py`. Does not affect FMHA.
|
||||
|
||||
---
|
||||
|
||||
### ⚠️ Speculative: Beyond V4 Paper Validation
|
||||
|
||||
The following are real potential wins but go beyond what the V4 paper explicitly validated for FP4. Listed for completeness, do NOT implement without explicit sign-off from Mike.
|
||||
|
||||
1. **Indexer FP4 tensor-core scoring (paper §5.2.1 "QK path in the indexer... cached, loaded, and multiplied entirely in FP4")**
|
||||
- Paper says the indexer SHOULD do QK in FP4 with tensor cores
|
||||
- Current: scalar FP32 dot products with no tensor cores
|
||||
- Huge scope: 2-3 weeks minimum
|
||||
- **Risk:** FP4 dot product precision for index selection needs recall validation
|
||||
- **Verdict:** Track for Stage F. Do NVFP4-0.4 first.
|
||||
|
||||
2. **MXFP4 vs NVFP4 for indexer scoring** — not validated in the paper for indexer specifically. Evaluate after NVFP4-0.
|
||||
|
||||
3. **NVFP4 for full attention Q×K^T GEMM** — Already closed. NVFP4 Q×K^T is too lossy (cos 0.86 vs FP32). Attention stays FP16/FP32.
|
||||
|
||||
4. **Per-token FP8 activation scaling in FMHA** — Different precision model, not validated. Out of scope.
|
||||
|
||||
---
|
||||
|
||||
## NVFP4 Execution Order
|
||||
|
||||
| # | Task | Scope | Risk | Blocks | Est. |
|
||||
|---|------|-------|------|--------|------|
|
||||
| NVFP4-0.1 | sf_dtype tracing | Both | NONE — print only | D1 if wrong | 5 min |
|
||||
| NVFP4-0.2 | SF TMEM layout | Both | NONE — print only | D1 if wrong | 5 min |
|
||||
| NVFP4-0.3 | FP4 TMA element type | FMHA | NONE — print only | D1 if wrong | 5 min |
|
||||
| NVFP4-0.4 | MMA kind verification | GEMM | NONE — print only | everything | 5 min |
|
||||
| NVFP4-3 | use_2cta_instrs conditional | MoE | NONE — perf only | nothing | 30 min |
|
||||
| NVFP4-1.1 | Fuse FP4 quant into SwiGLU epilogue | MoE | NONE | nothing | 1 day |
|
||||
| NVFP4-1.2 | Fuse FP4 quant into invRoPE→wo_a | Attention | NONE | D1.5 | 1 day |
|
||||
| NVFP4-1.3 | Fuse FP4 quant into mHC mixing | Attention | NONE | post-D5 | 2 days |
|
||||
| D1.5 | Correction epilogue fix | FMHA | MEDIUM | NVFP4-1.2 | 2-3 hours |
|
||||
| NVFP4-2 | FP4 KV pipeline depth | FMHA | NONE — perf only | D1 | 1 day |
|
||||
|
||||
**NVFP4-0 results gate the critical path.** If NVFP4-0.1–0.4 find a wrong sf_dtype, the fix comes before D2. Everything else is either parallel or post-D1.
|
||||
|
||||
**NVFP4-3 (use_2cta_instrs) is the fastest win and has no dependencies.** Do it immediately after the NVFP4-0 prints.
|
||||
|
||||
**NVFP4-1.1 (fuse FP4 into SwiGLU) is the next-biggest win.** No FMHA dependency. Do it in parallel with D2.
|
||||
|
||||
**NVFP4-2 (FP4 KV) depends on D1 being solid.** Do after D2 or alongside hd=512 fix.
|
||||
@@ -232,212 +232,3 @@ One pass, one kernel. D5d NOT NEEDED.
|
||||
- **D5c ✅:** Sink bias as logit modification (cos 0.999996 single-tile AND multi-tile)
|
||||
- **D5d:** NOT NEEDED — sink bias approach supersedes fused merge epilogue
|
||||
|
||||
---
|
||||
|
||||
## NVFP4 Precision Roadmap (May 23, 2026)
|
||||
|
||||
Three honest buckets. A fourth speculative bucket flagged at the end.
|
||||
|
||||
### NVFP4-0: Verify Right Blackwell FP4 Primitives ⚡ DO FIRST
|
||||
|
||||
**No correctness or quality risk. Pure correctness of implementation.** If these are wrong, we're running wrong MMA shapes silently.
|
||||
|
||||
#### NVFP4-0.1 — sf_dtype tracing
|
||||
|
||||
**What:** Trace the SF dtype through the full pipeline: `gemm_runner.py` → `dense.py` → `blockscaled_utils` → TMEM layout.
|
||||
|
||||
**The problem:** `dense.py` line 137 says NVF4 supports `Float8E8M0FNU/Float8E4M3FN` at sf_vec_size=16. But UE8M0 is the MXFP4/MXFP8 scale format. NVFP4 uses **FP8 E4M3**. The examples on lines 90/100 show `Float8E8M0FNU` at sf_vec_size=16 which is the MXFP4 path. **Need to verify the runner is passing E4M3, not E8M0.**
|
||||
|
||||
Action:
|
||||
- [ ] Print `sf_dtype` in `gemm_runner.py` at construction: `print(f"sf_dtype={sf_dtype}, sf_vec_size={SF_VEC_SIZE}")`
|
||||
- [ ] Print `self.sf_dtype` in `dense.py` `BlockScaledGEMM.__init__`
|
||||
- [ ] Print `self.sf_vec_size` in `dense.py`
|
||||
- [ ] Trace through `blockscaled_utils.make_sm100_sf_layout` — does it produce E4M3 packing (4 FP8 E4M3 → 1 int32) or UE8M0 packing?
|
||||
- [ ] **If wrong sf_dtype is found:** fix in `gemm_runner.py` SF_DTYPE constant, retest MoE cosine
|
||||
|
||||
#### NVFP4-0.2 — SF TMEM layout verification
|
||||
|
||||
**What:** NVFP4 expects scale factors in TMEM in a specific transposed-packed layout. UE4M3 for NVFP4 (4 packed FP8 E4M3 per int32 word). The comment in `dense.py` about "SM100 requires scaling factors in packed UE8M0 format" is for **MXFP8**, not NVFP4.
|
||||
|
||||
Action:
|
||||
- [ ] Print TMEM scale-factor offsets at GEMM construction: `print(f"sf_smem_layout={sf_smem_layout}, sf_tmem_offset={sf_tmem_offset}")`
|
||||
- [ ] Verify the packing matches UE4M3 (NVFP4) not UE8M0 (MXFP8)
|
||||
- [ ] Trace `blockscaled_utils.make_sm100_sf_layout` and print the output layout
|
||||
- [ ] **If wrong packing:** fix `make_sm100_sf_layout` or add NVFP4-specific layout path
|
||||
|
||||
#### NVFP4-0.3 — FP4 TMA element type
|
||||
|
||||
**What:** `float4_e2m1fn_x2` must survive all the way into TMA descriptor creation. Blackstone TMA supports `e2m1_x2` packed-FP4 element type directly. Loading as `uint8` works but loses tensor-core awareness.
|
||||
|
||||
Action:
|
||||
- [ ] Trace `float4_e2m1fn_x2` through `quantize.py` → TMA atom creation in `fmha.py`
|
||||
- [ ] Print the GMEM tensor dtype at FMHA kernel input
|
||||
- [ ] Print the TMA atom dtype at construction
|
||||
- [ ] Verify `cpasync.tma_partition` receives `float4_e2m1fn_x2` element type, not uint8
|
||||
- [ ] **If uint8 fallback:** fix TMA atom creation in `fmha.py`
|
||||
|
||||
#### NVFP4-0.4 — MMA kind is mxf4nvf4
|
||||
|
||||
**What:** Blackwell has a single MMA kind for both MXFP4 and NVFP4. NVFP4 = scales are FP8 E4M3, 16-element block. MXFP4 = scales are UE8M0, 32-element block. The MMA kind is determined by scale-factor type at runtime. Need to confirm tcgen05 is inferring NVFP4.
|
||||
|
||||
Action:
|
||||
- [ ] Print `tcgen05.mma.kind` at GEMM construction (if accessible)
|
||||
- [ ] Print the MMA instruction shape `(M, N, K)` confirmed by JIT compile
|
||||
- [ ] Verify it matches Blackwell MMA shape for NVFP4 (not MXFP4)
|
||||
|
||||
**Execution:** These are 5-minute print jobs. Do all 4 NVFP4-0 items before touching any code. If any of them reveals a wrong dtype, fix it FIRST before anything else. A wrong sf_dtype poisons every FP4 GEMM result.
|
||||
|
||||
---
|
||||
|
||||
### NVFP4-1: Eliminate BF16 Round-Trips After FP4 GEMMs 🔴 PURE-WIN, NO QUALITY RISK
|
||||
|
||||
**These are pure bandwidth/compute wins. The math doesn't change — we just avoid precision loss and kernel launch overhead.**
|
||||
|
||||
#### NVFP4-1.1 — Fuse FP4 quant into SwiGLU epilogue (MoE L1 → L2)
|
||||
|
||||
**What:** Current MoE forward:
|
||||
```
|
||||
padded_x_fp4 → L1 GEMM → SwiGLU → BF16 GMEM ← LEAK
|
||||
quantize_activation_nvfp4 ← SEPARATE KERNEL
|
||||
padded_activated_fp4 → L2 GEMM → BF16 GMEM
|
||||
```
|
||||
|
||||
**Paper §4.2.2:** NVFP4 weights. L1 → SwiGLU → online amax → FP8 scale + FP4 pack → FP4 GMEM → L2 GEMM.
|
||||
|
||||
The amax reduction is in-registers: for an epi tile with 16 contiguous elements per thread, each tile produces one FP8 E4M3 scale and 64 bits of packed FP4 nibbles. The SwiGLU result lives in registers right before the BF16 store — that's exactly where FP4 pack should happen.
|
||||
|
||||
**What you save:**
|
||||
- ~2× GMEM bandwidth between L1 and L2 (FP4 instead of BF16)
|
||||
- Entire `quantize_activation_nvfp4` kernel launch
|
||||
- `padded_activated_fp4` / `padded_activated_x_sf` scratch buffers
|
||||
- GPU-side amax computation (runs on tensor cores vs scalar)
|
||||
- L2 scale-factor TMA reads FP8 scales L1 just produced
|
||||
|
||||
**How:** Extend the fused SwiGLU epilogue. After computing `gate * up`:
|
||||
1. Compute per-16-element amax across the subtile (all-reduce or butterfly shfl_xor)
|
||||
2. Compute FP8 E4M3 scale = amax / 448 (E4M3 max)
|
||||
3. Pack each element: `sign_bit << 7 | (clamped_val / scale).to(uint4)`
|
||||
4. Write packed nibbles to GMEM as `float4_e2m1fn_x2`
|
||||
5. Write FP8 scale to SF TMA buffer
|
||||
|
||||
**The amax subtlety:** For NVFP4 the microblock is 16 elements. Port the same 16-element logic from `quantize.py` into the epilogue. Do NOT use 32-element MXFP4 microblocks.
|
||||
|
||||
- [ ] Extend `dsv4/kernels/gemm/fused_swiglu.py` epilogue to FP4 pack SwiGLU output
|
||||
- [ ] Add per-subtile amax reduction (register-only, no extra kernel)
|
||||
- [ ] Verify: L1 → L2 cosine matches reference (no regression from BF16 intermediate)
|
||||
- [ ] Verify: L2 GEMM reads FP4 scales produced by L1 epilogue
|
||||
- [ ] **Test:** MoE layer output cosine with full L1→L2 pipeline
|
||||
- [ ] **Scope:** MoE-side, NOT fmha.py. Does not block FMHA D1.
|
||||
|
||||
#### NVFP4-1.2 — Fuse FP4 quant into inverse RoPE → wo_a path
|
||||
|
||||
**What:** `inverse_rope_bf16` produces BF16, then `wo_a` quantizes it. Fuse FP4 quant into inverse RoPE epilogue.
|
||||
|
||||
Same pattern as NVFP4-1.1: after inverse RoPE rotation, compute amax → FP8 scale → FP4 pack → FP4 GMEM. The `wo_a` GEMM reads FP4 + scales.
|
||||
|
||||
- [ ] Extend `dsv4/ops/rope.py` inverse RoPE to emit FP4 instead of BF16
|
||||
- [ ] Wire `wo_a` GEMM to read FP4 scales from inverse RoPE output
|
||||
- [ ] **Test:** attention sub-block output cosine (full inverse RoPE → wo_a → attention)
|
||||
|
||||
#### NVFP4-1.3 — Fuse FP4 quant into mHC mixing → attention/FFN input
|
||||
|
||||
**What:** `B_l @ X_l + C_l ⊗ F_out` (mHC post_block) lands in BF16. Attention's `q_down` and FFN's L1 GEMM quantize it. Fuse quant into mHC mixing kernel.
|
||||
|
||||
Same pattern. After mHC mixing post-compute, amax → FP8 scale → FP4 pack → FP4 GMEM. Attention and FFN GEMMs read FP4.
|
||||
|
||||
- [ ] Add FP4 epilogue to mHC mixing kernel (when building it)
|
||||
- [ ] Wire attention `q_down` and FFN L1 to read mHC FP4 output
|
||||
- [ ] **Test:** end-to-end layer cosine (mHC → attention → FFN)
|
||||
|
||||
**Note:** NVFP4-1.2 and NVFP4-1.3 depend on D1.5 (correction epilogue fix) because those epilogues need the clean one-way TMEM path. NVFP4-1.1 (MoE SwiGLU) is independent.
|
||||
|
||||
---
|
||||
|
||||
### NVFP4-2: FP4 KV Pipeline Depth in FMHA 🔴 STAGE D, DEPENDS ON D1
|
||||
|
||||
**FP4 KV shrinks tiles 4×, same SMEM budget buys 3× more pipeline stages.**
|
||||
|
||||
| KV dtype | Tile size (hd=512) | 2 stages | 4 stages | 6 stages |
|
||||
|-----------|--------------------|----------|----------|----------|
|
||||
| BF16 | 128 KB (K+V) | 512 KB ✅ | — | — |
|
||||
| FP8 | 64 KB (K+V) | 256 KB ✅ | 512 KB | — |
|
||||
| FP4 | ~36 KB (K+V) | 144 KB ✅ | 288 KB | 432 KB |
|
||||
|
||||
Each extra stage hides more TMA latency. At 1M-context decode where KV reads dominate, deeper pipelines are a major perf win.
|
||||
|
||||
**Implementation:**
|
||||
- [ ] After D1 (SMEM-P works with BF16): add FP4 TMA load + SMEM dequant path
|
||||
- [ ] TMA loads FP4 NoPE dims (packed e2m1_x2) to SMEM slot 0
|
||||
- [ ] TMA loads BF16 RoPE dims to SMEM slot 1
|
||||
- [ ] TMA loads FP8 scale factors to SMEM slot 2
|
||||
- [ ] Dequantize FP4→BF16 in SMEM (vectorized `* FP8_scale`, 16-element microblocks)
|
||||
- [ ] Concatenate [NoPE, RoPE] in SMEM
|
||||
- [ ] MMA reads contiguous BF16 from SMEM
|
||||
- [ ] **Prerequisite:** D1 working at BF16 first. Cannot skip.
|
||||
- [ ] **Test:** FP4+BF16 split input → identical output to pure BF16 input (dequant is transparent)
|
||||
|
||||
---
|
||||
|
||||
### NVFP4-3: use_2cta_instrs for Production MoE 🟢 30 MINUTES, PURE PERF
|
||||
|
||||
**This is the single biggest single-knob perf win for FP4 GEMMs on B200.**
|
||||
|
||||
**What:** `FusedSwiGLUScaledGroupedGemmKernel` supports 2-CTA UMMA but defaults to `False`. With 2-CTA, the B operand is TMA-multicast: each CTA reads half of B, peers cross the Infiniband link. Effective MMA tile M doubles (128→256, 256→512).
|
||||
|
||||
**Measured win:** 1.7–1.9× throughput over single-CTA at prefill/batch shapes.
|
||||
|
||||
**Decision tree:**
|
||||
- M < 128 (decode single-token): 1-CTA is correct. 2-CTA wastes hardware.
|
||||
- M ≥ 256 (prefill or batched decode): 2-CTA is free perf.
|
||||
- cluster_m must be even for 2-CTA.
|
||||
|
||||
Action:
|
||||
- [ ] Add conditional: `use_2cta_instrs = (M >= 256 and cluster_m % 2 == 0)`
|
||||
- [ ] Default stays `False` (correct for decode)
|
||||
- [ ] Python GEMM runner sets `use_2cta_instrs=True` for prefill shapes
|
||||
- [ ] **Test:** throughput comparison at M=256, 512, 1024
|
||||
- [ ] **Scope:** MoE-side, `gemm_runner.py`. Does not affect FMHA.
|
||||
|
||||
---
|
||||
|
||||
### ⚠️ Speculative: Beyond V4 Paper Validation
|
||||
|
||||
The following are real potential wins but go beyond what the V4 paper explicitly validated for FP4. Listed for completeness, do NOT implement without explicit sign-off from Mike.
|
||||
|
||||
1. **Indexer FP4 tensor-core scoring (paper §5.2.1 "QK path in the indexer... cached, loaded, and multiplied entirely in FP4")**
|
||||
- Paper says the indexer SHOULD do QK in FP4 with tensor cores
|
||||
- Current: scalar FP32 dot products with no tensor cores
|
||||
- Huge scope: 2-3 weeks minimum
|
||||
- **Risk:** FP4 dot product precision for index selection needs recall validation
|
||||
- **Verdict:** Track for Stage F. Do NVFP4-0.4 first.
|
||||
|
||||
2. **MXFP4 vs NVFP4 for indexer scoring** — not validated in the paper for indexer specifically. Evaluate after NVFP4-0.
|
||||
|
||||
3. **NVFP4 for full attention Q×K^T GEMM** — Already closed. NVFP4 Q×K^T is too lossy (cos 0.86 vs FP32). Attention stays FP16/FP32.
|
||||
|
||||
4. **Per-token FP8 activation scaling in FMHA** — Different precision model, not validated. Out of scope.
|
||||
|
||||
---
|
||||
|
||||
## NVFP4 Execution Order
|
||||
|
||||
| # | Task | Scope | Risk | Blocks | Est. |
|
||||
|---|------|-------|------|--------|------|
|
||||
| NVFP4-0.1 | sf_dtype tracing | Both | NONE — print only | D1 if wrong | 5 min |
|
||||
| NVFP4-0.2 | SF TMEM layout | Both | NONE — print only | D1 if wrong | 5 min |
|
||||
| NVFP4-0.3 | FP4 TMA element type | FMHA | NONE — print only | D1 if wrong | 5 min |
|
||||
| NVFP4-0.4 | MMA kind verification | GEMM | NONE — print only | everything | 5 min |
|
||||
| NVFP4-3 | use_2cta_instrs conditional | MoE | NONE — perf only | nothing | 30 min |
|
||||
| NVFP4-1.1 | Fuse FP4 quant into SwiGLU epilogue | MoE | NONE | nothing | 1 day |
|
||||
| NVFP4-1.2 | Fuse FP4 quant into invRoPE→wo_a | Attention | NONE | D1.5 | 1 day |
|
||||
| NVFP4-1.3 | Fuse FP4 quant into mHC mixing | Attention | NONE | post-D5 | 2 days |
|
||||
| D1.5 | Correction epilogue fix | FMHA | MEDIUM | NVFP4-1.2 | 2-3 hours |
|
||||
| NVFP4-2 | FP4 KV pipeline depth | FMHA | NONE — perf only | D1 | 1 day |
|
||||
|
||||
**NVFP4-0 results gate the critical path.** If NVFP4-0.1–0.4 find a wrong sf_dtype, the fix comes before D2. Everything else is either parallel or post-D1.
|
||||
|
||||
**NVFP4-3 (use_2cta_instrs) is the fastest win and has no dependencies.** Do it immediately after the NVFP4-0 prints.
|
||||
|
||||
**NVFP4-1.1 (fuse FP4 into SwiGLU) is the next-biggest win.** No FMHA dependency. Do it in parallel with D2.
|
||||
|
||||
**NVFP4-2 (FP4 KV) depends on D1 being solid.** Do after D2 or alongside hd=512 fix.
|
||||
|
||||
Reference in New Issue
Block a user