biondizzle
dd7af0cd8a
feat: GPU-native SWA + sparse decode attention kernels (CuTeDSL)
- native_swa_decode.py: BlackwellSWADecodeKernel
  - CTA mapping: 1 CTA per (decode_token, q_head_group)
  - Online softmax with KV tile streaming (16 tokens/tile)
  - Pre-dequantized bf16 KV (fp8 dequant on host - MLIR cvt_fpext
    requires 32-bit aligned vector, no scalar fp8->bf16 support)
  - Cosine 0.9999+ vs PyTorch batched SDPA reference
  - Fallback _fallback_batched_sdp when CuTeDSL unavailable
- native_sparse_decode.py: BlackwellSparseDecodeKernel
  - Combined SWA + compressed KV in single attention pass
  - Supports CSA (cr=4) and HCA (cr=128) layers
  - Sink weight merge on host side
  - Cosine 0.9999+ vs combined SDPA reference
- fp8_bf16.py: Documents MLIR limitation (cvt_fpext requires
  vector<4xf8>, no scalar support). Pre-dequant is the workaround.
- vLLM wiring (attention.py):
  - SWA-only layers: native_swa_decode_attention
  - CSA/HCA layers: native_sparse_decode_attention with topk + attn_sink
  - csa_attention.py updated to use native kernels
- Tests: test_decode_pipeline.py, test_sparse_decode.py both passing
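
For readers unfamiliar with the "online softmax with KV tile streaming" step named above, the following is a minimal NumPy sketch of the general technique for a single query vector and a single head. It is not the CuTeDSL kernel from this commit: the function names (swa_decode_reference, swa_decode_online, cosine) are hypothetical, the sliding window is assumed to be simply the last `window` KV tokens, and the only detail taken from the commit message is the 16-token tile size.

```python
import numpy as np

TILE = 16  # tokens per KV tile (value taken from the commit message)


def swa_decode_reference(q, k, v, window):
    """Plain softmax attention over the last `window` KV tokens (window >= 1)."""
    k, v = k[-window:], v[-window:]
    scores = (k @ q) / np.sqrt(q.shape[0])
    p = np.exp(scores - scores.max())
    return (p / p.sum()) @ v


def swa_decode_online(q, k, v, window, tile=TILE):
    """Same result, computed one KV tile at a time with a running max and sum."""
    k, v = k[-window:], v[-window:]
    scale = 1.0 / np.sqrt(q.shape[0])
    m = -np.inf                  # running maximum of the scores seen so far
    l = 0.0                      # running softmax denominator
    acc = np.zeros(v.shape[1])   # running un-normalized output
    for start in range(0, k.shape[0], tile):
        s = (k[start:start + tile] @ q) * scale
        m_new = max(m, s.max())
        alpha = np.exp(m - m_new)        # rescales earlier tiles; 0 on the first tile
        p = np.exp(s - m_new)
        l = l * alpha + p.sum()
        acc = acc * alpha + p @ v[start:start + tile]
        m = m_new
    return acc / l


def cosine(a, b):
    """Cosine similarity, the metric the commit message quotes against the reference."""
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))
```

In this sketch the streamed result can be compared against the reference with cosine(swa_decode_online(q, k, v, window), swa_decode_reference(q, k, v, window)), which mirrors the kind of check the commit reports against PyTorch SDPA, although the actual test setup is not shown here.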
2026-05-20 05:46:15 +00:00