Commit Graph

5 Commits

Author SHA1 Message Date
4826fa6afb D2: add num_query_heads/batch_size params + head-packed test
- FmhaKernel.__init__: add num_query_heads=1, batch_size=1
- Grid: (ceil_div(n_h*T, 128), 1, batch) for multi-CTA
- Test: head-packed multi-head (Q reshaped to (n_h*T, hd))
- n_h=1 regression, n_h=128 Pro decode, n_h=64 Flash, hd=128
2026-05-25 16:50:49 +00:00
29ad36934d cleanup: remove D2 diagnostic/experimental files, keep working codebase clean 2026-05-25 02:40:12 +00:00
7599801f57 D2: add flat_divide shape diagnostic kernel for multi-CTA grid 2026-05-25 02:33:15 +00:00
6cc151097e Revert D2 multi-CTA attempts - keeping per-head launch approach (works correctly) 2026-05-25 01:08:38 +00:00
4c79e5533e D2: add multi-CTA grid with block_idx_y for Q/O head indexing 2026-05-24 23:27:38 +00:00