From 9ba6476d3f4f9b97f6128a6ebf84ee262c1fea85 Mon Sep 17 00:00:00 2001
From: biondizzle
Date: Tue, 2 Jun 2026 21:39:01 +0000
Subject: [PATCH] auto: pre-test commit

---
 FINAL_STRETCH.md | 22 +++++++++++++++++++++-
 1 file changed, 21 insertions(+), 1 deletion(-)

diff --git a/FINAL_STRETCH.md b/FINAL_STRETCH.md
index 0fcbf024..e5ee80f2 100644
--- a/FINAL_STRETCH.md
+++ b/FINAL_STRETCH.md
@@ -7,6 +7,23 @@ Goal: native NVFP4 where the math allows, FP8_E4M3 where it doesn't, BF16/FP32 o
 ## B0 — What's already optimal: DO NOT "fix" the MoE
 `dsv4/layers/moe.py` already runs **native NVFP4**: expert weights and activations are `float4_e2m1fn_x2`, block scales are `float8_e4m3fn`. This matches the paper (routed experts in FP4). Leave it. The remaining wins are in **attention** and the **indexer**, not MoE.

+### P5 — Fused mHC pre_block + RMSNorm + NVFP4 quantize: ✅ DONE
+- `fused_mhc_rmsnorm_quantize.cu` — 2-kernel approach (mhc_rmsnorm_amax_gsa + mhc_rmsnorm_quantize_nvfp4)
+- **Integrated into `forward_layer`** for BOTH attn and ffn mHC paths (commit 0b6ca0d)
+- Replaces: pre_block bmm (1 launch) + rmsnorm (4+ launches) + quantize (2 launches) → 2 launches
+- Savings: ~5 launches/site × 2 sites × 61 layers = 610 launches/token
+- Unit test: cos=0.999 vs unfused, 0.995 vs true mHC+RMSNorm at T=1/8/128
+- gsa per-row diff: ~1-2e-6 (excellent)
+
+### P4 — Fused RMSNorm + NVFP4 quantize: ✅ DONE
+- `fused_rmsnorm_quantize.cu` — 2-kernel approach
+- Integrated for standalone rmsnorm+quantize paths
+- gsa scalar fix in `Nvfp4Linear.run_from_quantized`: per-row gsa reduced to scalar (max) for GEMM compatibility
+
+### Stale Lock Fix: ✅ DONE (commit 845227c)
+- `dsv4/kernels/cuda/loader.py`: _cleanup_stale_lock() removes lock files older than 10 minutes
+- Prevents infinite spin after crash/kill during CUDA kernel compilation
+
 ## B1 — FP8_E4M3 FMHA (BIG win; perf + memory + native Blackwell)
 Today: KV is *stored* mixed (FP8 nope + BF16 rope), then in "5. Gather KV" it's **dequantized to BF16** into `gbuf`, and the FMHA runs in **BF16**. That throws away the FP8 you stored and runs the heaviest kernel at half the tensor-core throughput Blackwell offers.

@@ -20,7 +37,10 @@ NVFP4 KV is correctly ruled out — your own `KVCache` docstring shows 4-bit KV
 `single_shot_inference.py` indexer scoring is `torch.einsum('tnd,cd->tnc', q_idx.float(), k_idx.float())` → **full FP32 einsum on CUDA cores over all `n_comp` entries, every CSA layer, every decode step.** At long context this is the dominant indexer cost and it's the *opposite* of native-FP4. The indexer keys are already FP8 in cache. Replace with a tensor-core **weighted-ReLU MQA-logits kernel** in FP8 (or FP4 for the QK path, as the paper does: "lightning indexer ... FP4"). Mirror DeepGEMM `fp8_fp4_mqa_logits`. This is both the long-context perf unlock and a native-FP4 conversion. (The dead `dsv4/kernels/indexer/*.cu` is not this — write it fresh against the DeepGEMM kernel, score in FP8/FP4, top-k with a warp-local reduction, no global lock.)

 ## B3 — Fused rmsnorm→quant for q_a_norm / kv_norm (small, removes BF16 round-trips)
-The fused `rmsnorm_quantize_nvfp4` path is used for the mHC attn/ffn input norms (good), but `q_a_norm` and `kv_norm` in `forward_attention` still call plain `rmsnorm` (returns BF16) and then re-quantize downstream. This is exactly the "produce BF16, immediately re-cast" pattern. Extend the fused fp32→(FP4/FP8) norm+quant to `q_a_norm` (feeds `q_b`) and `kv_norm` (feeds the FMHA), emitting the consumer's target dtype directly and skipping the intermediate BF16 materialization. Modest but on the per-token hot path.
+- ✅ DONE: `q_a_norm` → `q_b` path now uses fused `rmsnorm_quantize_nvfp4` + `run_from_quantized` (commit 0b6ca0d)
+- Skips BF16 materialization between q_a_norm and q_b GEMM
+- Saves ~6 kernel launches per layer
+- `kv_norm` still uses unfused rmsnorm — requires FP8 FMHA (B1) to fully benefit, since kv goes to RoPE not another GEMM

 ## B4 — General "producer BF16 → consumer FP32" sweep (the user's pattern)
 Find and fix places that cast up immediately after producing a narrower dtype: