This commit is contained in:
2026-06-02 22:31:13 +00:00
parent 9d4a014fad
commit 2eb4f0886e

View File

@@ -32,6 +32,7 @@ NVFP4 KV is correctly ruled out — your own `KVCache` docstring shows 4-bit KV
- Quantize `q` to FP8 before the FMHA (it's BF16 now; see B3). Blackwell FP8 MMA consumes FP8×FP8.
- Wins: removes the per-entry dequant, **halves `gbuf` bandwidth** (the per-step gather is on the decode hot path), and uses FP8 tensor cores. The DeepGEMM reference `fp8_mqa_logits` / FP8 attention paths are the template.
- Gate it behind a cos check vs the BF16 FMHA per layer; if rope-in-FP8 drops cos, keep rope BF16.
- DeepGemm will probably show E4M3 for forward passes and E5M2 for gradients, which is correct
## B2 — Indexer scoring on FP8/FP4 tensor cores (BIG at long context; native FP4)
`single_shot_inference.py` indexer scoring is `torch.einsum('tnd,cd->tnc', q_idx.float(), k_idx.float())`**full FP32 einsum on CUDA cores over all `n_comp` entries, every CSA layer, every decode step.** At long context this is the dominant indexer cost and it's the *opposite* of native-FP4. The indexer keys are already FP8 in cache. Replace with a tensor-core **weighted-ReLU MQA-logits kernel** in FP8 (or FP4 for the QK path, as the paper does: "lightning indexer ... FP4"). Mirror DeepGEMM `fp8_fp4_mqa_logits`. This is both the long-context perf unlock and a native-FP4 conversion. (The dead `dsv4/kernels/indexer/*.cu` is not this — write it fresh against the DeepGEMM kernel, score in FP8/FP4, top-k with a warp-local reduction, no global lock.)