nvfp4-megamoe-kernel/dsv4/kernels/decode/_NOTES_fp8_bf16.md
biondizzle 3fb3c925af Restructure: cutedsl/ -> dsv4/ with proper layering
- Split bridge.py -> ops/quantize.py, ops/layouts.py, ops/gemm_runner.py
- Renamed classes: CuTeDSLNvfp4Linear -> Nvfp4Linear, etc.
- Moved kernel code to dsv4/kernels/ (gemm, attention, compressor, decode, cuda)
- Moved PyTorch bridges to dsv4/ops/
- Moved nn.Module layers to dsv4layers/
- Moved reference implementations to dsv4/reference/
- Moved vendored CUTLASS code to vendored/
- Archived ~190 debug tests to tests/archive/
- Kept ~15 canonical tests in tests/unit/
- Updated all import paths
- Added stubs for future components (model/, cache/, loader/)
- Updated pyproject.toml: dsv4-inference package name
2026-05-21 17:30:44 +00:00


FP8 E4M3 -> BF16 conversion for CuTeDSL on Blackwell (SM100+).

STATUS: NOT USABLE INSIDE CUTE KERNELS.

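The MLIR nvgpu.cvt_fpext op (which CuTeDSL's .to(BFloat16) generates) requires its operand to be a 1-d vector whose total width is 32-bit aligned, e.g. vector<4xf8E4M3FN> (4 × 8 bits = 32 bits). A scalar f8E4M3FN operand is rejected, so scalar fp8→bf16 conversion is NOT supported through this op. Calling val_fp8.to(BFloat16) on a scalar inside a @cute.kernel fails with: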

'nvgpu.cvt_fpext' op operand #0 must be 32-bits aligned signless-integer-like or floating-point-like 1-d vector, but got 'f8E4M3FN'
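A minimal sketch of the constraint as the error text states it, assuming "32-bits aligned" means the vector's total bit width is a multiple of 32 and that element 0 occupies the low byte of a packed word. fpext_operand_ok and pack_fp8x4 are hypothetical helpers, not CuTeDSL or MLIR APIs; they only show why a scalar fp8 value is rejected while vector<4xf8E4M3FN> fills exactly one 32-bit register.

```python
FP8_BITS = 8  # f8E4M3FN stores one value per byte


def fpext_operand_ok(shape: tuple[int, ...]) -> bool:
    """Hypothetical check mirroring the error text (assumption): the operand
    must be a 1-d vector whose total width is a multiple of 32 bits.
    A scalar has shape () and is rejected outright."""
    return len(shape) == 1 and (shape[0] * FP8_BITS) % 32 == 0


def pack_fp8x4(raw: bytes) -> int:
    """Pack four fp8 bytes into one 32-bit word, element 0 in the low byte
    (assumed little-endian register layout for vector<4xf8E4M3FN>)."""
    if len(raw) != 4:
        raise ValueError("need exactly four fp8 bytes")
    return int.from_bytes(raw, "little")


print(fpext_operand_ok(()))     # False: scalar f8E4M3FN
print(fpext_operand_ok((4,)))   # True: vector<4xf8E4M3FN> is 32 bits
print(hex(pack_fp8x4(bytes([0x38, 0x40, 0x48, 0xB8]))))  # 0xb8484038
```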

WORKAROUND: Pre-dequantize fp8→bf16 in PyTorch before launching the kernel, i.e. from host code rather than inside the CuTe kernel. This is what native_swa_decode_attention and native_sparse_decode_attention already do. The cost is negligible:

  • One batched PyTorch expression: fp8.to(bf16) * inv_scale
  • Memory: ~4 MiB extra for a typical decode batch (32 tokens × 128 window × 512 dim × 2 bytes per bf16 element)
  • ≈0.002% of B200's 192 GB HBM
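The numerics of the workaround can be checked without a GPU. The pure-Python sketch below decodes E4M3FN bytes (exponent bias 7, no infinities, max finite value 448), rounds to bf16 with round-to-nearest-even, and confirms that the fp8→bf16 step itself is exact, so the only rounding comes from multiplying by inv_scale. All helper names are hypothetical, inv_scale is assumed to be a per-tensor scalar, and extra_bf16_bytes reproduces the buffer-size arithmetic from the memory bullet.

```python
import math
import struct


def e4m3fn_to_float(byte: int) -> float:
    """Decode one FP8 E4M3FN byte (bias 7, no infinities, NaN = S.1111.111)."""
    sign = -1.0 if byte & 0x80 else 1.0
    exp = (byte >> 3) & 0xF
    man = byte & 0x7
    if exp == 0xF and man == 0x7:
        return math.nan
    if exp == 0:  # subnormal: 0.mmm * 2^-6
        return sign * (man / 8) * 2.0**-6
    return sign * (1 + man / 8) * 2.0 ** (exp - 7)


def float_to_bf16_bits(x: float) -> int:
    """Round to bf16 bits: x is first rounded to float32 by struct.pack,
    then round-to-nearest-even drops the low 16 bits."""
    if math.isnan(x):
        return 0x7FC0  # canonical quiet NaN
    (bits,) = struct.unpack("<I", struct.pack("<f", x))
    return ((bits + 0x7FFF + ((bits >> 16) & 1)) >> 16) & 0xFFFF


def bf16_bits_to_float(b: int) -> float:
    return struct.unpack("<f", struct.pack("<I", b << 16))[0]


def dequant(fp8_bytes: bytes, inv_scale: float) -> list[float]:
    """Element-by-element model of fp8.to(bf16) * inv_scale:
    fp8 -> bf16 (exact), then round the scaled product back to bf16."""
    out = []
    for byte in fp8_bytes:
        as_bf16 = bf16_bits_to_float(float_to_bf16_bits(e4m3fn_to_float(byte)))
        out.append(bf16_bits_to_float(float_to_bf16_bits(as_bf16 * inv_scale)))
    return out


def extra_bf16_bytes(tokens: int, window: int, dim: int) -> int:
    """Size of the dequantized bf16 buffer (2 bytes per element)."""
    return tokens * window * dim * 2


# Every finite E4M3 value survives the fp8 -> bf16 step unchanged.
finite = [b for b in range(256) if (b & 0x7F) != 0x7F]
assert all(
    bf16_bits_to_float(float_to_bf16_bits(e4m3fn_to_float(b))) == e4m3fn_to_float(b)
    for b in finite
)
print(extra_bf16_bytes(32, 128, 512))  # 4194304 bytes = 4 MiB
```

In the real workaround the same two steps run as one batched PyTorch expression on device tensors; the Python loop is only a way to check the numerics.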

FUTURE: When CuTeDSL/MLIR adds support for scalar fp8→bf16 conversion, or when we can properly construct vector<4xf8E4M3FN> inside kernel code, we can fuse the dequant into the attention kernel. The PTX instruction exists (cvt.rn.bf16x2.e4m3x2), but CuTeDSL's AST preprocessor currently prevents us from injecting the necessary MLIR ops.
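To illustrate what a fused in-kernel conversion would compute, the sketch below emulates cvt.rn.bf16x2.e4m3x2 in pure Python: each 16-bit half of a 32-bit vector<4xf8E4M3FN> word produces one packed bf16x2 word, so one 32-bit load needs two conversions. It assumes the low fp8 byte maps to the low bf16 half; check the PTX ISA before relying on that ordering. All function names are hypothetical and nothing here is CuTeDSL API.

```python
import math
import struct


def e4m3fn_to_float(byte: int) -> float:
    """Decode one FP8 E4M3FN byte (bias 7, no infinities, NaN = S.1111.111)."""
    sign = -1.0 if byte & 0x80 else 1.0
    exp, man = (byte >> 3) & 0xF, byte & 0x7
    if exp == 0xF and man == 0x7:
        return math.nan
    if exp == 0:
        return sign * (man / 8) * 2.0**-6
    return sign * (1 + man / 8) * 2.0 ** (exp - 7)


def float_to_bf16_bits(x: float) -> int:
    """Round to bf16 bits (round-to-nearest-even on the float32 pattern)."""
    if math.isnan(x):
        return 0x7FC0
    (bits,) = struct.unpack("<I", struct.pack("<f", x))
    return ((bits + 0x7FFF + ((bits >> 16) & 1)) >> 16) & 0xFFFF


def cvt_bf16x2_e4m3x2(pair: int) -> int:
    """Emulate one cvt.rn.bf16x2.e4m3x2: 16 bits (two fp8) -> 32 bits (two bf16).
    Assumption: the low input byte becomes the low output half."""
    lo = float_to_bf16_bits(e4m3fn_to_float(pair & 0xFF))
    hi = float_to_bf16_bits(e4m3fn_to_float((pair >> 8) & 0xFF))
    return (hi << 16) | lo


def convert_fp8x4_word(word: int) -> tuple[int, int]:
    """One 32-bit vector<4xf8E4M3FN> word -> two bf16x2 words (two cvt ops)."""
    return cvt_bf16x2_e4m3x2(word & 0xFFFF), cvt_bf16x2_e4m3x2((word >> 16) & 0xFFFF)


# fp8 bytes [1.0, 2.0, 4.0, -1.0] packed little-endian into one word
lo_pair, hi_pair = convert_fp8x4_word(0xB8484038)
print(hex(lo_pair), hex(hi_pair))  # 0x40003f80 0xbf804080
```

Because every finite E4M3 value is exactly representable in bf16, the rounding mode of the conversion never matters; only the later multiply by inv_scale rounds.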