- native_swa_decode.py: BlackwellSWADecodeKernel
  - CTA mapping: 1 CTA per (decode_token, q_head_group)
  - Online softmax with KV tile streaming (16 tokens/tile)
  - Pre-dequantized bf16 KV (fp8 dequant on host; MLIR cvt_fpext
    requires a 32-bit aligned vector and has no scalar fp8->bf16 support)
  - Cosine similarity 0.9999+ vs PyTorch batched SDPA reference
  - Falls back to _fallback_batched_sdp when CuTeDSL is unavailable
- native_sparse_decode.py: BlackwellSparseDecodeKernel
  - Combined SWA + compressed KV in a single attention pass
  - Supports CSA (cr=4) and HCA (cr=128) layers
  - Sink weight merge on host side
  - Cosine similarity 0.9999+ vs combined SDPA reference
- fp8_bf16.py: Documents the MLIR limitation (cvt_fpext requires
  vector<4xf8>, no scalar support). Pre-dequant is the workaround.
- vLLM wiring (attention.py):
  - SWA-only layers: native_swa_decode_attention
  - CSA/HCA layers: native_sparse_decode_attention with topk + attn_sink
- csa_attention.py updated to use the native kernels
- Tests: test_decode_pipeline.py and test_sparse_decode.py both pass
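The online-softmax tile streaming listed for BlackwellSWADecodeKernel can be illustrated with a minimal pure-Python sketch for one query vector, roughly the work one CTA does for one head. This is not the CuTeDSL kernel. The function names are hypothetical, the 1/sqrt(d) score scale is an assumption, and everything runs in float64. The tiled result is compared against a naive full-softmax reference using cosine similarity, the same metric as the 0.9999+ check quoted in the list.

```python
import math
import random
from typing import List, Sequence

Vec = List[float]


def _scores(q: Sequence[float], keys: Sequence[Sequence[float]]) -> Vec:
    """Scaled dot-product scores q.k / sqrt(d) for one query (assumed scale)."""
    scale = 1.0 / math.sqrt(len(q))
    return [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in keys]


def naive_decode_attention(q, keys, values) -> Vec:
    """Reference path: materialize every score, softmax, weight the values."""
    scores = _scores(q, keys)
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    return [sum(w * v[d] for w, v in zip(weights, values)) / total
            for d in range(len(values[0]))]


def online_softmax_decode(q, keys, values, tile_size: int = 16) -> Vec:
    """Stream K/V in tiles of `tile_size` tokens. Only a running max, a running
    softmax denominator and an unnormalized output accumulator are kept."""
    run_max, run_sum = -math.inf, 0.0
    acc = [0.0] * len(values[0])
    for start in range(0, len(keys), tile_size):
        scores = _scores(q, keys[start:start + tile_size])
        new_max = max(run_max, max(scores))
        # Rescale earlier partial sums to the new max; exp(-inf) == 0.0 on tile 0.
        correction = math.exp(run_max - new_max)
        run_sum *= correction
        acc = [a * correction for a in acc]
        for s, v in zip(scores, values[start:start + tile_size]):
            p = math.exp(s - new_max)
            run_sum += p
            acc = [a + p * vd for a, vd in zip(acc, v)]
        run_max = new_max
    return [a / run_sum for a in acc]


def cosine(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity between two output vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    return dot / math.sqrt(sum(x * x for x in a) * sum(y * y for y in b))


if __name__ == "__main__":
    random.seed(0)
    dim, window = 64, 128  # illustrative sizes, not taken from the kernels
    q = [random.gauss(0, 1) for _ in range(dim)]
    keys = [[random.gauss(0, 1) for _ in range(dim)] for _ in range(window)]
    values = [[random.gauss(0, 1) for _ in range(dim)] for _ in range(window)]
    out = online_softmax_decode(q, keys, values, tile_size=16)
    print(cosine(out, naive_decode_attention(q, keys, values)))
```

The rescaling step is what allows a kernel to keep only one 16-token tile of scores live at a time. The final division by the running denominator happens once, after the last tile.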
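BlackwellSparseDecodeKernel has two distinctive steps: one softmax over SWA plus compressed KV, and a sink merge done on the host. Both can be checked with a small pure-Python sketch. The sketch rests on several assumptions. The attention sink (`attn_sink`) is a per-head scalar logit that adds to the softmax denominator but carries no value vector. The kernel returns the normalized output together with its log-sum-exp (LSE). The compressed entries passed in are already the top-k selection, so CSA/HCA compression and selection are not shown. All function names are hypothetical.

```python
import math
from typing import List, Sequence, Tuple

Vec = List[float]


def attention_with_lse(q, keys, values) -> Tuple[Vec, float]:
    """Softmax attention for one query; also returns the LSE of the scores."""
    scale = 1.0 / math.sqrt(len(q))  # assumed score scale
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    out = [sum(w * v[d] for w, v in zip(weights, values)) / total
           for d in range(len(values[0]))]
    return out, m + math.log(total)


def combined_decode(q, swa_k, swa_v, cmp_k, cmp_v) -> Tuple[Vec, float]:
    """One softmax shared by the sliding-window KV and the (already selected)
    compressed KV entries, instead of two separately normalized passes."""
    return attention_with_lse(q, list(swa_k) + list(cmp_k), list(swa_v) + list(cmp_v))


def merge_sink(out: Sequence[float], lse: float, sink_logit: float) -> Vec:
    """Host-side sink merge. The sink adds exp(sink_logit) to the denominator
    with a zero value, so out_final = out * exp(lse) / (exp(lse) + exp(sink))
    = out * sigmoid(lse - sink_logit)."""
    factor = 1.0 / (1.0 + math.exp(sink_logit - lse))
    return [o * factor for o in out]


def reference_with_sink(q, keys, values, sink_logit: float) -> Vec:
    """Direct check: the sink logit sits inside the softmax with a zero value."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(max(scores), sink_logit)
    weights = [math.exp(s - m) for s in scores]
    denom = sum(weights) + math.exp(sink_logit - m)
    return [sum(w * v[d] for w, v in zip(weights, values)) / denom
            for d in range(len(values[0]))]


if __name__ == "__main__":
    dim = 4
    q = [0.5, -1.0, 0.25, 2.0]
    swa_k = [[math.sin(i + d) for d in range(dim)] for i in range(8)]
    swa_v = [[math.cos(i - d) for d in range(dim)] for i in range(8)]
    cmp_k = [[math.sin(3 * i + d) for d in range(dim)] for i in range(2)]
    cmp_v = [[math.cos(3 * i + d) for d in range(dim)] for i in range(2)]
    out, lse = combined_decode(q, swa_k, swa_v, cmp_k, cmp_v)
    print(merge_sink(out, lse, sink_logit=0.3))
```

Merging the sink after the kernel only needs the LSE, so the kernel itself can stay a plain softmax over the concatenated KV set.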
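The host-side fp8 dequant used as the workaround (`fp8.to(bf16) * inv_scale`) can be written out element by element. The sketch below decodes E4M3FN, the format MLIR reports as `f8E4M3FN`, then rounds the scaled value to bf16. It assumes a single per-tensor `inv_scale` float and float32 arithmetic before the bf16 rounding. The function names are hypothetical, and this is not what PyTorch executes internally.

```python
import math
import struct
from typing import List


def e4m3fn_to_float(byte: int) -> float:
    """Decode one FP8 E4M3FN byte: 1 sign bit, 4 exponent bits (bias 7),
    3 mantissa bits; no infinities, and only S.1111.111 encodes NaN."""
    sign = -1.0 if byte & 0x80 else 1.0
    exp = (byte >> 3) & 0xF
    man = byte & 0x7
    if exp == 0xF and man == 0x7:
        return math.nan
    if exp == 0:  # subnormal: no implicit leading 1
        return sign * (man / 8.0) * 2.0 ** -6
    return sign * (1.0 + man / 8.0) * 2.0 ** (exp - 7)


def round_to_bf16(x: float) -> float:
    """Round to float32, then to the nearest bfloat16 (ties to even) by
    keeping the upper 16 bits of the float32 bit pattern."""
    (bits,) = struct.unpack("<I", struct.pack("<f", x))
    bits = (bits + 0x7FFF + ((bits >> 16) & 1)) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def dequantize_fp8(raw: bytes, inv_scale: float) -> List[float]:
    """Element-wise equivalent of fp8.to(bf16) * inv_scale. Every E4M3 value is
    exactly representable in bf16, so rounding happens only after the multiply."""
    return [round_to_bf16(e4m3fn_to_float(b) * inv_scale) for b in raw]


if __name__ == "__main__":
    raw = bytes([0x38, 0x40, 0x7E, 0x01])  # 1.0, 2.0, 448.0 (max finite), 2**-9
    print(dequantize_fp8(raw, inv_scale=0.5))  # [0.5, 1.0, 224.0, 0.0009765625]
```

Because the fp8-to-bf16 step is exact, doing it on the host instead of in the kernel changes no values. Only the extra bf16 buffer is added.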
"""
FP8 E4M3 -> BF16 conversion for CuTeDSL on Blackwell (SM100+).

STATUS: NOT USABLE INSIDE CUTE KERNELS.

The MLIR nvgpu.cvt_fpext op (which CuTeDSL's .to(BFloat16) generates)
requires a 32-bit aligned 1-d vector operand. Scalar fp8→bf16 conversion
is NOT supported by MLIR. Attempting val_fp8.to(BFloat16) inside a
@cute.kernel produces:

    'nvgpu.cvt_fpext' op operand #0 must be 32-bits aligned signless-integer-like
    or floating-point-like 1-d vector, but got 'f8E4M3FN'

WORKAROUND: Pre-dequantize fp8→bf16 on the host side before launching the
kernel. This is what native_swa_decode_attention and native_sparse_decode_attention
already do. The cost is negligible:
  - Single batched torch op: (fp8.to(bf16) * inv_scale)
  - Memory: ~5 MB extra for a typical decode batch (32 tokens × 128 window × 512 dim),
    i.e. ~0.0026% of a B200's 192 GB of HBM

FUTURE: When CuTeDSL/MLIR adds support for scalar fp8→bf16 conversion,
or when we can properly construct vector<4xf8E4M3FN> inside kernel code,
we can fuse the dequant into the attention kernel. The PTX instruction
exists (cvt.rn.bf16x2.e4m3x2), but CuTeDSL's AST preprocessor currently
prevents us from injecting the necessary MLIR ops.
"""
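The memory estimate in the fp8_bf16.py docstring can be reproduced under one assumption: the extra buffer is a single bf16 (2-byte) copy of 32 × 128 × 512 elements, for example a latent shared by keys and values. Separate K and V buffers would double the figure. The helper names are hypothetical.

```python
def extra_bf16_bytes(tokens: int, window: int, dim: int, bytes_per_elem: int = 2) -> int:
    """Bytes for one bf16 copy of a [tokens, window, dim] KV buffer."""
    return tokens * window * dim * bytes_per_elem


def fraction_of_hbm_percent(nbytes: float, hbm_bytes: float = 192e9) -> float:
    """Share of device memory in percent (192 GB HBM, as quoted for B200)."""
    return 100.0 * nbytes / hbm_bytes


if __name__ == "__main__":
    extra = extra_bf16_bytes(tokens=32, window=128, dim=512)
    print(extra)                                     # 4194304 bytes (~4.2 MB)
    print(f"{fraction_of_hbm_percent(extra):.4f}%")  # 0.0022%
    print(f"{fraction_of_hbm_percent(5e6):.4f}%")    # 0.0026% for the quoted ~5 MB
```

The product comes to about 4.2 MB, a little under the ~5 MB quoted. Plugging the rounded 5 MB into the same ratio gives the docstring's 0.0026%.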