- gather_compressed_kv: CSA top-k gather via existing gather_kv.cu
- gather_all_compressed_kv: HCA dense gather via new gather_all_compressed_kernel
- gather_swa_kv: SWA ring buffer gather via new gather_swa_kernel
- Added gather_swa.cu with both SWA + all-compressed gather kernels
- Added gather.py Python wrapper (torch.utils.cpp_extension JIT)
- Updated handle.py: added schema field, num_query_heads/head_dim properties
- Updated manager.py: passes schema + num_query_heads to handle

All gather kernels: FP8→BF16 dequant + BF16 RoPE concat in single launch.
Output: dense BF16 tensors ready for FMHA consumption.
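
For reviewers, here is a rough pure-Python reference of what a ring buffer gather with fused dequant + RoPE concat computes. It is a sketch under assumptions, not the code in gather_swa.cu: the function name, the argument layout, the per-slot scale, and the "token at position p lives in slot p % window" convention are all hypothetical, and plain floats stand in for FP8/BF16 values.

```python
from typing import List, Sequence


def gather_swa_reference(
    ring_nope_q: Sequence[Sequence[float]],  # [window][d_nope], stand-in for FP8 payload
    ring_rope: Sequence[Sequence[float]],    # [window][d_rope], stand-in for BF16 RoPE part
    scales: Sequence[float],                 # [window], assumed per-slot dequant scale
    seq_len: int,                            # number of tokens written so far
) -> List[List[float]]:
    """Return the valid window in chronological order, one dense row per token."""
    window = len(ring_nope_q)
    n_valid = min(seq_len, window)           # buffer may not be full yet
    start = seq_len - n_valid                # oldest position still in the window
    out: List[List[float]] = []
    for pos in range(start, seq_len):
        slot = pos % window                  # assumed ring-buffer slot mapping
        dequant = [v * scales[slot] for v in ring_nope_q[slot]]  # "FP8 -> BF16" step
        out.append(dequant + list(ring_rope[slot]))              # concat with RoPE part
    return out


# Window of 4 after 6 tokens: slots hold tokens 4, 5, 2, 3 (payload = token id).
rows = gather_swa_reference(
    ring_nope_q=[[4.0], [5.0], [2.0], [3.0]],
    ring_rope=[[40.0], [50.0], [20.0], [30.0]],
    scales=[0.5, 0.5, 0.5, 0.5],
    seq_len=6,
)
print(rows)  # [[1.0, 20.0], [1.5, 30.0], [2.0, 40.0], [2.5, 50.0]]
```

A reference like this is mainly useful for checking kernel output ordering when the window has wrapped.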