things
This commit is contained in:
@@ -32,6 +32,7 @@ NVFP4 KV is correctly ruled out — your own `KVCache` docstring shows 4-bit KV
|
||||
- Quantize `q` to FP8 before the FMHA (it's BF16 now; see B3). Blackwell FP8 MMA consumes FP8×FP8.
|
||||
- Wins: removes the per-entry dequant, **halves `gbuf` bandwidth** (the per-step gather is on the decode hot path), and uses FP8 tensor cores. The DeepGEMM reference `fp8_mqa_logits` / FP8 attention paths are the template.
|
||||
- Gate it behind a cos check vs the BF16 FMHA per layer; if rope-in-FP8 drops cos, keep rope BF16.
|
||||
- DeepGemm will probably show E4M3 for forward passes and E5M2 for gradients, which is correct
|
||||
|
||||
## B2 — Indexer scoring on FP8/FP4 tensor cores (BIG at long context; native FP4)
|
||||
`single_shot_inference.py` indexer scoring is `torch.einsum('tnd,cd->tnc', q_idx.float(), k_idx.float())` → **full FP32 einsum on CUDA cores over all `n_comp` entries, every CSA layer, every decode step.** At long context this is the dominant indexer cost and it's the *opposite* of native-FP4. The indexer keys are already FP8 in cache. Replace with a tensor-core **weighted-ReLU MQA-logits kernel** in FP8 (or FP4 for the QK path, as the paper does: "lightning indexer ... FP4"). Mirror DeepGEMM `fp8_fp4_mqa_logits`. This is both the long-context perf unlock and a native-FP4 conversion. (The dead `dsv4/kernels/indexer/*.cu` is not this — write it fresh against the DeepGEMM kernel, score in FP8/FP4, top-k with a warp-local reduction, no global lock.)
|
||||
|
||||
Reference in New Issue
Block a user