nvfp4-megamoe-kernel/cutedsl/fp8_bf16.py
biondizzle bbba289bd8 feat: GPU-native SWA + sparse decode attention kernels (CuTeDSL)
- native_swa_decode.py: BlackwellSWADecodeKernel
  - CTA mapping: 1 CTA per (decode_token, q_head_group)
  - Online softmax with KV tile streaming (16 tokens/tile)
  - Pre-dequantized bf16 KV (fp8 dequant on host - MLIR cvt_fpext
    requires 32-bit aligned vector, no scalar fp8->bf16 support)
  - Cosine 0.9999+ vs PyTorch batched SDPA reference
  - Fallback _fallback_batched_sdp when CuTeDSL unavailable

- native_sparse_decode.py: BlackwellSparseDecodeKernel
  - Combined SWA + compressed KV in single attention pass
  - Supports CSA (cr=4) and HCA (cr=128) layers
  - Sink weight merge on host side
  - Cosine 0.9999+ vs combined SDPA reference

- fp8_bf16.py: Documents MLIR limitation (cvt_fpext requires
  vector<4xf8>, no scalar support). Pre-dequant is the workaround.

- vLLM wiring (attention.py):
  - SWA-only layers: native_swa_decode_attention
  - CSA/HCA layers: native_sparse_decode_attention with topk + attn_sink
  - csa_attention.py updated to use native kernels

- Tests: test_decode_pipeline.py, test_sparse_decode.py both passing
2026-05-20 05:46:15 +00:00

"""
FP8 E4M3 -> BF16 conversion for CuTeDSL on Blackwell (SM100+).

STATUS: NOT USABLE INSIDE CUTE KERNELS.

The MLIR nvgpu.cvt_fpext op (which CuTeDSL's .to(BFloat16) generates)
requires a 32-bit aligned 1-d vector operand, so scalar fp8→bf16
conversion is NOT supported. Attempting val_fp8.to(BFloat16) inside a
@cute.kernel produces:

    'nvgpu.cvt_fpext' op operand #0 must be 32-bits aligned signless-integer-like
    or floating-point-like 1-d vector, but got 'f8E4M3FN'

WORKAROUND: Pre-dequantize fp8→bf16 on the host side before launching the
kernel. This is what native_swa_decode_attention and
native_sparse_decode_attention already do. The cost is small:

- Single batched torch op: (fp8.to(bf16) * inv_scale)
- Memory: ~4 MiB extra for a typical decode batch
  (32 tokens × 128 window × 512 dim × 2 bytes per bf16 element)
- That is roughly 0.002% of B200's 192 GB HBM

FUTURE: When CuTeDSL/MLIR adds support for scalar fp8→bf16 conversion,
or when we can properly construct vector<4xf8E4M3FN> inside kernel code,
we can fuse the dequant into the attention kernel. The PTX instruction
exists (cvt.rn.bf16x2.e4m3x2), but CuTeDSL's AST preprocessor currently
prevents us from injecting the necessary MLIR ops.
"""
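The docstring describes the workaround only as a single batched torch op. As a rough illustration of what that op computes per element, here is a dependency-free reference sketch in plain Python. It is not the code used by the kernels: the helper names (`decode_e4m3fn`, `round_to_bf16`, `dequantize`, `bf16_buffer_bytes`) are hypothetical, and treating `inv_scale` as a single scalar is an assumption made for the example. It also reproduces the buffer-size arithmetic for the decode batch quoted above.

```python
import math
import struct


def decode_e4m3fn(byte: int) -> float:
    """Decode one FP8 E4M3FN byte: 1 sign, 4 exponent (bias 7), 3 mantissa bits."""
    sign = -1.0 if byte & 0x80 else 1.0
    exp = (byte >> 3) & 0xF
    man = byte & 0x7
    if exp == 0xF and man == 0x7:
        return math.nan  # E4M3FN has no infinities; this bit pattern is NaN
    if exp == 0:
        return sign * (man / 8.0) * 2.0 ** -6  # subnormal range
    return sign * (1.0 + man / 8.0) * 2.0 ** (exp - 7)


def round_to_bf16(x: float) -> float:
    """Round a float to the nearest bf16 value (ties to even), returned as a float."""
    if math.isnan(x):
        return x
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # bf16 keeps the top 16 bits of the float32 encoding.
    bits = (bits + 0x7FFF + ((bits >> 16) & 1)) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def dequantize(raw: bytes, inv_scale: float) -> list:
    """Reference for the per-element math behind fp8.to(bf16) * inv_scale."""
    return [round_to_bf16(decode_e4m3fn(b) * inv_scale) for b in raw]


def bf16_buffer_bytes(tokens: int, window: int, dim: int) -> int:
    """Size of the pre-dequantized bf16 buffer (2 bytes per element)."""
    return tokens * window * dim * 2


if __name__ == "__main__":
    print(dequantize(bytes([0x38, 0xC0, 0x7E]), 0.5))  # [0.5, -1.0, 224.0]
    print(bf16_buffer_bytes(32, 128, 512) / 2**20)      # 4.0 (MiB)
```

Every finite E4M3FN value fits in bf16 without loss (3 mantissa bits against 7), so rounding in this sketch only happens after the multiply by `inv_scale`.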