FP8 E4M3 -> BF16 conversion for CuTeDSL on Blackwell (SM100+).

STATUS: NOT USABLE INSIDE CUTE KERNELS.

The MLIR `nvgpu.cvt_fpext` op, which CuTeDSL's `.to(BFloat16)` generates, requires a 32-bit-aligned 1-D vector operand, such as four FP8 lanes packed into 32 bits. It rejects a scalar fp8 operand, so scalar fp8→bf16 conversion cannot be lowered through it. Calling `val_fp8.to(BFloat16)` inside a `@cute.kernel` produces:

`'nvgpu.cvt_fpext' op operand #0 must be 32-bits aligned signless-integer-like or floating-point-like 1-d vector, but got 'f8E4M3FN'`
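
For reference, the per-element conversion the kernel needs is exact: BF16 has 8 exponent bits and 7 mantissa bits, so every E4M3FN value (4 exponent bits with bias 7, 3 mantissa bits, no infinities) is representable in BF16 without rounding. Below is a minimal pure-Python sketch of the bit-level semantics. It could serve as a host-side oracle when testing a fused kernel later. The helper names are hypothetical and are not CuTeDSL or PyTorch APIs.

```python
import math
import struct


def e4m3fn_to_float(byte: int) -> float:
    """Decode one FP8 E4M3FN byte: 1 sign, 4 exponent (bias 7), 3 mantissa bits."""
    sign = -1.0 if byte & 0x80 else 1.0
    exp = (byte >> 3) & 0xF
    mant = byte & 0x7
    if exp == 0xF and mant == 0x7:
        return math.nan  # E4M3FN has no infinities; S.1111.111 encodes NaN
    if exp == 0:
        return sign * math.ldexp(mant, -9)  # subnormal: (mant / 8) * 2**(1 - 7)
    return sign * math.ldexp(8 + mant, exp - 10)  # normal: (1 + mant / 8) * 2**(exp - 7)


def float_to_bf16_bits(x: float) -> int:
    """Return the BF16 bit pattern as the top 16 bits of the float32 encoding.

    Truncation is exact here because every E4M3FN value fits in BF16.
    """
    (bits,) = struct.unpack("<I", struct.pack("<f", x))
    return bits >> 16


def e4m3fn_to_bf16_bits(byte: int) -> int:
    """Reference scalar FP8 E4M3FN -> BF16 conversion (bit pattern in, bit pattern out)."""
    return float_to_bf16_bits(e4m3fn_to_float(byte))


if __name__ == "__main__":
    for b in (0x38, 0x7E, 0x01, 0xB8):
        print(f"0x{b:02X} -> {e4m3fn_to_float(b)!r} -> 0x{e4m3fn_to_bf16_bits(b):04X}")
```

Running it prints each byte's decoded value and BF16 pattern. For example, `0x38` (1.0) maps to `0x3F80`, and `0x7E` (448.0, the largest finite E4M3FN value) maps to `0x43E0`.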

WORKAROUND: Pre-dequantize fp8→bf16 on the host side (in PyTorch, outside the kernel) before launching the kernel. `native_swa_decode_attention` and `native_sparse_decode_attention` already do this. The cost is negligible:
- One batched PyTorch expression (a cast plus a scale): `fp8.to(bf16) * inv_scale`
- Memory: ~4 MiB extra for a typical decode batch (32 tokens × 128 window × 512 dim × 2 bytes per BF16 element)
- About 0.002% of B200's 192 GB HBM
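
The memory figure can be sanity-checked with a back-of-envelope calculation. This assumes the dequantized buffer is a single BF16 tensor of the stated shape (2 bytes per element) and ignores per-token scale tensors and transient temporaries. `bf16_copy_bytes` is a hypothetical helper.

```python
def bf16_copy_bytes(tokens: int, window: int, dim: int) -> int:
    """Bytes for one BF16 copy of a [tokens, window, dim] slice (2 bytes per element)."""
    return tokens * window * dim * 2


extra = bf16_copy_bytes(tokens=32, window=128, dim=512)
hbm_bytes = 192e9  # B200 HBM capacity as quoted in these notes (decimal GB)
print(f"{extra} B = {extra / 2**20:.1f} MiB = {100 * extra / hbm_bytes:.4f}% of HBM")
```

In eager PyTorch, `fp8.to(bf16)` materializes a BF16 temporary before the multiply produces the result. Transient peak usage is therefore roughly twice this figure, which is still small.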

FUTURE: Once CuTeDSL/MLIR supports scalar fp8→bf16 conversion, or once we can properly construct a `vector<4xf8E4M3FN>` inside kernel code, we can fuse the dequant into the attention kernel. The PTX instruction exists (`cvt.rn.bf16x2.e4m3x2`), but CuTeDSL's AST preprocessor currently prevents us from injecting the necessary MLIR ops.
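
A fused path would operate on packed lanes. A 32-bit word holding four E4M3FN values (the `vector<4xf8E4M3FN>` case) yields two 32-bit words of packed BF16 pairs, one per two-lane conversion. The sketch below emulates that data flow in plain Python using integer-only bit manipulation, similar to a software fallback a kernel might use. It is not the PTX instruction itself. The lane order (lane 0 in the low byte, and the low FP8 of each pair landing in the low BF16 half) is an assumption to check against the PTX ISA. All function names are hypothetical.

```python
def e4m3fn_lane_to_bf16(byte: int) -> int:
    """Integer-only FP8 E4M3FN -> BF16 bit conversion for one lane (exact)."""
    sign = (byte & 0x80) << 8          # sign bit moves from bit 7 to bit 15
    exp = (byte >> 3) & 0xF            # E4M3 exponent, bias 7
    mant = byte & 0x7                  # 3 mantissa bits
    if exp == 0xF and mant == 0x7:
        return sign | 0x7FC0           # NaN (E4M3FN has no infinities)
    if exp == 0:
        if mant == 0:
            return sign                # signed zero
        # Subnormal mant * 2**-9: normalize so the leading 1 becomes implicit.
        p = mant.bit_length() - 1
        frac = (mant - (1 << p)) << (7 - p)
        return sign | ((p + 118) << 7) | frac   # BF16 exponent = p - 9 + 127
    # Normal: rebias the exponent (7 -> 127) and widen the mantissa (3 -> 7 bits).
    return sign | ((exp + 120) << 7) | (mant << 4)


def cvt_e4m3x2_to_bf16x2(pair16: int) -> int:
    """Convert two packed E4M3FN lanes (16 bits) to two packed BF16 lanes (32 bits).

    Assumed lane order: low byte -> low 16-bit half.
    """
    lo = e4m3fn_lane_to_bf16(pair16 & 0xFF)
    hi = e4m3fn_lane_to_bf16((pair16 >> 8) & 0xFF)
    return (hi << 16) | lo


def cvt_fp8x4_word(word32: int) -> tuple[int, int]:
    """Split a 32-bit word of four E4M3FN lanes into two bf16x2 results."""
    return (cvt_e4m3x2_to_bf16x2(word32 & 0xFFFF),
            cvt_e4m3x2_to_bf16x2((word32 >> 16) & 0xFFFF))


def pack_fp8x4(lanes: list[int]) -> int:
    """Pack four FP8 bytes into one 32-bit word, lane 0 in the low byte."""
    assert len(lanes) == 4
    return sum((b & 0xFF) << (8 * i) for i, b in enumerate(lanes))


if __name__ == "__main__":
    word = pack_fp8x4([0x38, 0x7E, 0x01, 0xB8])   # 1.0, 448.0, 2**-9, -1.0
    lo, hi = cvt_fp8x4_word(word)
    print(f"0x{word:08X} -> 0x{lo:08X}, 0x{hi:08X}")
```

Running it prints `0xB8017E38 -> 0x43E03F80, 0xBF803B00`. Each 16-bit half of the input becomes one full 32-bit bf16x2 result, which is why a 4-lane FP8 vector is the natural 32-bit-aligned unit here.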