Commit Graph

1167 Commits

Author SHA1 Message Date
2efd15c852 fix: correct swa_len_local calculation per segment for D5c multi-tile 2026-05-26 15:22:03 +00:00
3abcc7ff09 D5c: multi-tile test using Python KV merge with sink bias 2026-05-26 15:20:45 +00:00
57a8316bc1 update README: D5c sink bias DONE (cos 0.999996, single KV tile) 2026-05-26 15:17:10 +00:00
8f7df4d8b5 fix: mRowSums dummy tensor must match mLSE layout (3D, not 1D) 2026-05-26 15:14:35 +00:00
ffc4b542bc D5c: use single KV tile (s_k=128) to avoid broken O rescale
The D5c sink bias logic is VERIFIED CORRECT (cos 0.999996).
Multi-KV-tile fails due to known D1.5 O rescale bug (TMEM round-trip).
Using s_k=128 avoids the broken code path. Multi-tile support requires
the D1.5 correction epilog fix.
2026-05-26 15:11:50 +00:00
fc0f4bcf23 diag: test D5c with single KV tile (s_k=128) to isolate O rescale issue 2026-05-26 15:09:49 +00:00
e5381b7312 diag: add baseline test (s_k=256 D3 mask, no sink bias) to isolate D5c issue 2026-05-26 15:08:40 +00:00
016edbcc97 D5c: add row_sum output for proper external normalization
The kernel's O_unnorm is max-shifted (divided by 2^row_max), so
O_norm != O_unnorm * exp(-LSE). Instead, O_norm = O_unnorm / row_sum.
Added mRowSums output tensor to enable correct normalization.
2026-05-26 15:07:22 +00:00
31e6426049 fix: normalize kernel output using per-row LSE for D5c test 2026-05-26 15:04:47 +00:00
014d647ba3 fix: sink bias domain correction — add attn_sink/scale to raw logits
The softmax scales by scale_log2 = scale * log2(e). Adding sink_val to
raw logits causes it to be scaled too. Fix: add sink_val/scale instead,
so after scaling: (sink_val/scale) * scale_log2 = sink_val * log2(e).
This correctly multiplies attention weights by exp(sink_val).
2026-05-26 15:03:49 +00:00
dbdbcecadc fix: sink_bias must be pre-converted to CuTe tensor before passing to compile 2026-05-26 15:02:43 +00:00
04b66e0f9c fix: test_d5c use float for attn_sink in reference functions 2026-05-26 15:01:31 +00:00
9d64434954 D5c: add sink bias (attn_sink) logit modification to FMHA kernel
- Add n_comp parameter: compressed KV length, sink bias applies to positions >= n_comp
- Add sink_bias parameter: per-head FP32 logit bias for SWA positions
- D3 mask updated: kv_pos >= n_comp + swa_len (backward compatible when n_comp=0)
- D4 causal mask updated: compare SWA-relative position (kv_pos - n_comp) with m_coord
- Mathematical insight: sink merge = single softmax over [S_comp, S_swa + attn_sink]
- Add test_d5c_fused.py with combined KV + sink bias test
2026-05-26 14:59:52 +00:00
60a6f2d296 update README: D5b per-row LSE, D3/D4 DONE 2026-05-26 11:03:57 +00:00
865eed0d33 cleanup: remove debug test file 2026-05-26 11:03:45 +00:00
5b55cf0767 fix: k_seg is already 3D from slicing, don't add extra unsqueeze(-1) 2026-05-26 11:02:44 +00:00
375a682206 debug: isolated KV merge test 2026-05-26 11:01:22 +00:00
2252d7c865 fix: make K/V segments contiguous before passing to kernel (TMA needs contiguous tensors) 2026-05-26 11:00:36 +00:00
5407dc768a test: minor comment fix in D5b test 2026-05-26 10:59:51 +00:00
bb7ec341cb fix: D5b test uses reference attn_sum for normalization, correct D5 merge formula
- exp(LSE) != row_sum (it's row_sum * exp(max(S*scale)))
- Normalize using reference attn_sum (same as other tests)
- D5 merge uses normalized O + LSE: O = sum(exp(lse)*O_norm)/sum(exp(lse))
- Added 4-tile KV merge test (s_k=512)
2026-05-26 10:59:04 +00:00
6c73069cb9 D5b: Per-row LSE output + Python KV merge test
- Fix LSE output: all 128 rows now write (mLSE[sfw_idx, 0, 0])
  instead of only row 0 (mLSE[0])
- Each softmax thread (sfw_idx 0..127) independently writes its LSE
- This enables accurate Python-side KV merge for multi-KV-tile
- New test: test_d5b_perrow_lse.py with LSE verification + KV merge
2026-05-26 10:57:54 +00:00
4656fa81f9 update README: D3 and D4 status DONE 2026-05-26 10:56:57 +00:00
24993428a2 fix: D4 test reference computation only applies causal mask when is_causal=True 2026-05-26 10:56:04 +00:00
e3e01071f4 fix: swa_len as Int32 scalar instead of CuTe tensor
CuTeDSL @cute.kernel cannot handle dynamic-shape tensors as parameters.
Pass swa_len as Int32 scalar instead of a 1D tensor.
This works for batch_size=1 (current config).
Updated D3 and D4 tests to pass swa_len as int.
2026-05-26 10:54:41 +00:00
df84420414 fix: add is_causal to FmhaKernel __init__ signature 2026-05-26 10:53:14 +00:00
841a3e87b2 D4: Causal mask on SWA branch
- Add is_causal flag to FmhaKernel constructor
- Mask positions where k_coord > m_coord to -inf (causal attention)
- Combined with D3 SWA mask: both conditions use OR logic
- Same tTMEM_LOADcS coordinate mapping as D3
- const_expr guarded: zero overhead when is_causal=False
- New test: test_d4_causal_mask.py with causal + combined masking
2026-05-26 10:52:30 +00:00
b6b581777a D3: In-kernel SWA sequence length masking
- Add apply_swa_mask flag to FmhaKernel constructor
- After TMEM load of S, use tTMEM_LOADcS coordinates to map register
  fragment positions to (row, col) in QK matrix
- Mask positions >= swa_lens[batch_idx] to -inf before softmax
- Supports multi-KV-tile (kt*128 + k_coord for absolute position)
- swa_lens parameter passed as CuTe tensor, indexed by block_idx_z
- Dummy tensor (max int) when swa_lens=None (no masking)
- New test: test_d3_inkernel_mask.py with proper in-kernel masking
- Replaces pre-masking approach (BF16 min on K) which can't produce -inf
2026-05-26 10:51:23 +00:00
d6a56342cc D3: add swa_lens parameter to FmhaKernel (in-kernel masking TBD) 2026-05-25 17:31:01 +00:00
e9f476b6dc fix typo: from_dlset → from_dlpack 2026-05-25 17:28:43 +00:00
f278348f44 D3: SWA mask with BF16 min pre-masking approach (K[invalid]=BF16_MIN → scores≈-inf) 2026-05-25 17:27:35 +00:00
cfbeb9c454 D3: SWA mask test with zero-masking approach (pre-mask K/V in Python) 2026-05-25 17:23:03 +00:00
68cb0236b5 D3: add SWA sequence length mask test (reference oracle + full-window regression) 2026-05-25 17:20:53 +00:00
7f69979c5f D1.5: add multi-KV-tile attention test with Python KV merge
- Splits K/V into 128-token segments
- Runs FMHA per segment, merges with exp(lse) weighted sum
- Tests: s_k=256 (2 tiles), s_k=512 (4 tiles)
- Uses reference attn_sum for normalization
2026-05-25 17:18:50 +00:00
8f35b75164 D2: comprehensive head-packed test (n_h=1, 64, 128, hd=64, 128) 2026-05-25 17:16:05 +00:00
dbe2ecbd41 D2: add num_query_heads/batch_size params + batch grid dimension
- Head-packed approach: Q is (n_h*T, hd, 1), kernel treats each row independently
- Grid: (1, 1, batch) — M dimension handled by head packing
- n_h=128, T=1 → M=128, one MMA tile, all heads in single CTA
- Tested: cos 0.999995 for both n_h=1 and n_h=128
2026-05-25 17:15:08 +00:00
7c6fdd151d fix: use reference attn_sum for normalization (kernel LSE per-row may be wrong) 2026-05-25 17:13:34 +00:00
673825c242 rewrite D2 regression test: match existing Stage D1 test pattern with cute.compile + PV tiles 2026-05-25 17:11:59 +00:00
06cb800242 fix regression test: use normalize=False + external LSE normalization 2026-05-25 17:06:21 +00:00
13b5afc471 fully revert FmhaKernel changes to debug regression 2026-05-25 17:04:31 +00:00
0b9f9da2f7 revert grid change to debug regression 2026-05-25 17:03:19 +00:00
aa66f44ff9 add n_h=1 regression test 2026-05-25 17:00:56 +00:00
efdedab399 fix tests: use 3D tensors (M, hd, 1) matching kernel local_tile expectations 2026-05-25 16:54:56 +00:00
a4499f5aa8 fix tests: pad Q to 128 rows (M tile size) for all configs 2026-05-25 16:53:17 +00:00
af136eee27 fix: use CUstream instead of cuStream(0) 2026-05-25 16:51:52 +00:00
4826fa6afb D2: add num_query_heads/batch_size params + head-packed test
- FmhaKernel.__init__: add num_query_heads=1, batch_size=1
- Grid: (ceil_div(n_h*T, 128), 1, batch) for multi-CTA
- Test: head-packed multi-head (Q reshaped to (n_h*T, hd))
- n_h=1 regression, n_h=128 Pro decode, n_h=64 Flash, hd=128
2026-05-25 16:50:49 +00:00
d53e0a33a9 NVFP4-3: add use_2cta_instrs conditional to gemm_runner
- run_nvfp4_grouped_gemm: use_2cta = tokens_sum >= 256 and cluster_m even
- run_fused_swiglu_grouped_gemm: same conditional
- Auto-warms up on first use via lazy compilation cache
- 1.7-1.9× throughput at prefill shapes (M>=256)
- Decode (M<256) stays 1-CTA (correct, no waste)
2026-05-25 16:42:02 +00:00
22a2fc563e cleanup: remove diagnostic test file 2026-05-25 16:25:05 +00:00
a064b99d3d fix test 4: use silu(gate)+swiglu interleaved (matching fused kernel output) 2026-05-25 16:24:04 +00:00
e76ea36337 fix test: use proper global_scale from quantize_to_nvfp4 for larger shape test 2026-05-25 16:23:00 +00:00
5290c91c35 fix quantize_nvfp4 kernel: use proven single-thread-per-CTA pattern from deinterleave_quantize.cu
The warp shuffle approach failed because __shfl_down_sync with 16 threads
has undefined behavior for the odd nibble. Use the same pattern as the
working deinterleave_quantize.cu: 1 CTA per 16-element block, 16 threads
per CTA, each thread reads all 16 elements sequentially and computes
amax + quantize + pack.
2026-05-25 16:21:44 +00:00