2026-05-21 17:30:44 +00:00
2026-05-20 05:46:15 +00:00

FP8 E4M3 -> BF16 conversion for CuTeDSL on Blackwell (SM100+).

STATUS: NOT USABLE INSIDE CUTE KERNELS.

The MLIR nvgpu.cvt_fpext op (which CuTeDSL's .to(BFloat16) generates)
requires a 32-bit aligned 1-D vector operand, so scalar fp8→bf16 conversion
is NOT supported through this path. Attempting val_fp8.to(BFloat16) on a
scalar inside a @cute.kernel produces:

'nvgpu.cvt_fpext' op operand #0 must be 32-bits aligned signless-integer-like
or floating-point-like 1-d vector, but got 'f8E4M3FN'
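Read literally, the message requires a 1-D vector whose total width is a multiple of 32 bits. For 8-bit f8E4M3FN elements, that means a length divisible by 4, which is why the FUTURE note mentions vector<4xf8E4M3FN>. A bare scalar fails outright. The sketch below encodes that reading under two assumptions:

- cvt_fpext_operand_ok and min_aligned_vector_len are hypothetical helpers, not part of MLIR or CuTeDSL.
- The "total width" interpretation of "32-bits aligned" is our own.

```python
# Hypothetical helpers: a literal reading of the constraint quoted in the
# nvgpu.cvt_fpext error message. This is NOT the real MLIR verifier.
F8_BITS = 8  # f8E4M3FN occupies 8 bits


def cvt_fpext_operand_ok(vector_len, elem_bits=F8_BITS):
    """Return True if an operand would satisfy our reading of the constraint.

    vector_len=None models a scalar operand (a bare f8E4M3FN value).
    Otherwise the operand is a 1-D vector of vector_len elements, and we
    assume "32-bits aligned" means its total width is a multiple of 32 bits.
    """
    if vector_len is None:
        return False  # a scalar is not a 1-D vector at all
    return vector_len > 0 and (vector_len * elem_bits) % 32 == 0


def min_aligned_vector_len(elem_bits=F8_BITS):
    """Smallest vector length whose total width is a multiple of 32 bits."""
    n = 1
    while (n * elem_bits) % 32 != 0:
        n += 1
    return n


if __name__ == "__main__":
    for n in (None, 1, 2, 4, 8):
        print(f"len={n}: ok={cvt_fpext_operand_ok(n)}")
    print("minimum f8 vector length:", min_aligned_vector_len())
```

With 8-bit elements, only lengths 4 and 8 pass, and the minimum aligned length is 4. With 16-bit elements the minimum would be 2.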

WORKAROUND: Pre-dequantize fp8→bf16 outside the kernel, in host-side PyTorch
code, before launching it. native_swa_decode_attention and
native_sparse_decode_attention already do this. The cost is negligible:
- Single batched torch op: (fp8.to(bf16) * inv_scale)
- Memory: ~4.2 MB extra for a typical decode batch (32 tokens × 128 window × 512 dim
  = 2,097,152 bf16 elements × 2 bytes)
- About 0.002% of B200's 192 GB HBM
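The memory bullet can be checked with a few lines of arithmetic. This sketch makes two assumptions:

- A single bf16 tensor of the quoted decode shape is dequantized.
- "192 GB" means decimal gigabytes (192e9 bytes).

If K and V are dequantized into separate tensors, double the buffer size. The constant names are illustrative only.

```python
# Back-of-envelope cost of the host-side dequant buffer for the quoted
# decode shape. Assumes one bf16 tensor and decimal GB for the HBM size.
TOKENS = 32
WINDOW = 128
DIM = 512
BF16_BYTES = 2
FP8_BYTES = 1
B200_HBM_BYTES = 192 * 10**9  # assumption: 192 GB as decimal bytes


def dequant_buffer_bytes(tokens=TOKENS, window=WINDOW, dim=DIM):
    """Bytes of the extra bf16 tensor produced by fp8.to(bf16) * inv_scale."""
    return tokens * window * dim * BF16_BYTES


def hbm_fraction_percent(extra_bytes, hbm_bytes=B200_HBM_BYTES):
    """Share of total HBM taken by the extra buffer, in percent."""
    return 100.0 * extra_bytes / hbm_bytes


if __name__ == "__main__":
    elems = TOKENS * WINDOW * DIM
    extra = dequant_buffer_bytes()
    print(f"elements    : {elems:,}")
    print(f"bf16 buffer : {extra:,} bytes ({extra / 2**20:.1f} MiB)")
    print(f"fp8 source  : {elems * FP8_BYTES:,} bytes (already resident)")
    print(f"HBM share   : {hbm_fraction_percent(extra):.4f}%")
```

The extra buffer comes to 4,194,304 bytes (4 MiB), roughly 0.0022% of 192e9 bytes.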

FUTURE: Once CuTeDSL/MLIR supports scalar fp8→bf16 conversion, or once we can
build a vector<4xf8E4M3FN> operand inside kernel code, we can fuse the dequant
into the attention kernel. The PTX instruction exists (cvt.rn.bf16x2.e4m3x2),
but CuTeDSL's AST preprocessor currently prevents us from injecting the
necessary MLIR ops.
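A fused kernel will need a reference to be compared against, and a plain-Python model of the dequant math is one option. This sketch uses no CuTeDSL or PTX. It does two things:

- It decodes OCP FP8 E4M3FN bytes exactly: 1 sign bit, 4 exponent bits with bias 7, 3 mantissa bits, and no infinities.
- It rounds the scaled result to bf16 with round-to-nearest-even.

Every finite E4M3 value is exactly representable in bf16, so the fp8→bf16 step itself is lossless. The only rounding comes from multiplying by inv_scale. The function names are hypothetical. The model is element-wise only and does not reproduce the e4m3x2/bf16x2 register packing of the PTX instruction.

```python
import math
import struct


def e4m3fn_to_float(byte):
    """Decode one FP8 E4M3FN byte (1 sign, 4 exponent, 3 mantissa bits).

    E4M3FN has no infinities; S.1111.111 encodes NaN. Exponent bias is 7.
    """
    if not 0 <= byte <= 0xFF:
        raise ValueError("expected a single byte")
    sign = -1.0 if byte & 0x80 else 1.0
    exp = (byte >> 3) & 0xF
    man = byte & 0x7
    if exp == 0xF and man == 0x7:
        return math.nan
    if exp == 0:  # subnormal: (man / 8) * 2^-6
        return sign * math.ldexp(man / 8.0, -6)
    return sign * math.ldexp(1.0 + man / 8.0, exp - 7)


def round_to_bf16(x):
    """Round a Python float to the nearest bf16 value (ties to even).

    Rounds to float32 first (struct.pack), then keeps the top 16 bits with
    round-to-nearest-even. Raises OverflowError beyond the float32 range.
    """
    if math.isnan(x):
        return math.nan
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)  # RNE bias on the discarded half
    bits &= 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def dequant_reference(fp8_bytes, inv_scale):
    """Model of fp8.to(bf16) * inv_scale: exact decode, then bf16 rounding."""
    return [round_to_bf16(e4m3fn_to_float(b) * inv_scale) for b in fp8_bytes]


if __name__ == "__main__":
    print(dequant_reference([0x38, 0x40, 0x7E, 0x01], 0.5))
```

A fused kernel's output could be compared element-wise against dequant_reference on the same fp8 bytes and inv_scale.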