Commit Graph

554 Commits

Author SHA1 Message Date
3a25c7feff Test multi-KV merge (2 segments) separately from multi-head 2026-05-27 06:54:16 +00:00
e45b94c01b Test: compare both normalized and un-normalized reference 2026-05-27 06:44:37 +00:00
98c93c1cd8 Stage E: production attention wrapper + Python KV merge, clean fmha_smem_acc 2026-05-27 06:34:10 +00:00
b02e103ac0 Add c_simple GMEM tensor (non-dynamic) for SMEM accumulator TMA store 2026-05-27 05:33:30 +00:00
bf36979a8d Use CUTLASS FMHA reference pattern for sC->GMEM TMA store (flat_divide + tma_partition) 2026-05-27 05:24:39 +00:00
97bc6d8d2f Add c_direct GMEM tensor for direct writes in SMEM accumulator path 2026-05-27 05:15:47 +00:00
a858ed1c14 Fix test: normalize=False for un-normalized O comparison 2026-05-27 05:06:52 +00:00
3a7d87adba Fix test_smem_acc: use keyword args for lse/row_sums 2026-05-27 04:54:23 +00:00
6a621bdf64 D1.5: SMEM accumulator FMHA kernel — one-way TMEM→REGS→SMEM, no round-trip
TMEM round-trip (Ld32x32bOp/St32x32bOp) is FUNDAMENTALLY BROKEN.
Even NO-OP (multiply by 1.0) corrupts data.

New approach:
- PV always ACCUMULATE=False (fresh TMEM each kt)
- After pv_done_bar: one-way Ld32x32bOp load O_kt from TMEM→REGS
- Coordinate-indexed SMEM accumulation: sO_acc = acc_scale * sO_acc + O_kt
- sO_acc: FP32 [128, pv_n_tile] row-major (32KB at hd=64, 64KB at hd=128)
- Final: normalize, cast BF16, write to sC, TMA store to GMEM
2026-05-27 04:53:40 +00:00
42c5793add D1.5: Add isolated round-trip test comparing s_k=128 vs s_k=256 with NOOP rescale 2026-05-26 20:45:58 +00:00
3be708d923 D1.5 debug: add NOOP rescale test (acc_scale=1.0) to isolate TMEM round-trip corruption 2026-05-26 20:28:55 +00:00
c3648e4ebf D1.5 debug: add targeted s_k=256 rescale diagnostic test 2026-05-26 20:27:37 +00:00
bf2c7c8bb8 D1.5: Implement in-kernel O rescale via CUTLASS correction_rescale pattern
- Both load and store atoms built from SAME tOtO_i (composition-tiled)
- Same Repetition(corr_tile_size=16) for both copies
- pv_done_bar synchronization between MMA and softmax warps
- acc_scale computed per kt iteration, used to rescale O in TMEM
- const_expr(n_kv_tiles > 1) guards for zero overhead at s_k=128
- New test: test_d15_in_kernel_rescale.py (s_k=128/256/384)
- Minimal roundtrip test: test_tmem_roundtrip_minimal.py
2026-05-26 20:26:06 +00:00
2b4f4ce538 Remove broken D1.5 paired-atom test (TMEM round-trick is fundamentally broken) 2026-05-26 19:50:31 +00:00
40cbf0c223 Add D1.5 paired-atom O rescale test (s_k=256/384, hd=64/128) 2026-05-26 19:46:19 +00:00
487d960a6a D5c multi-tile: VERIFIED cos 0.999996 with Python KV merge + sink bias
Both segments (compressed+SWA with n_comp=96, and SWA-only with n_comp=0)
pass individually at cos 0.999996. The Python KV merge produces the
correct combined attention at cos 0.999996.

Key: n_comp is compile-time, so separate kernel instances are needed
for segments with different n_comp values. Production code would use
a kernel cache keyed on (n_comp, apply_sink_bias, ...).
2026-05-26 15:40:45 +00:00
c9eab3c7e0 diag: rewrite multi-tile test with explicit per-segment compile and reference 2026-05-26 15:39:39 +00:00
7f983fb855 diag: add direct segment 0 test to compare with run_segment 2026-05-26 15:37:06 +00:00
2a5f9dc6e3 fix: simplify run_segment to use full hd V tensor (was incorrectly splitting by pv_n_tile) 2026-05-26 15:34:57 +00:00
aa2df1a202 diag: test n_comp=96 with sink bias directly 2026-05-26 15:33:38 +00:00
25b236fe00 diag: test D5c multi-tile with no sink bias to isolate issue 2026-05-26 15:31:38 +00:00
a3989929de diag: per-segment reference comparison for D5c multi-tile 2026-05-26 15:29:52 +00:00
bbc29945e8 diag: add per-segment debug output for D5c multi-tile 2026-05-26 15:28:19 +00:00
e64392f1ac D5c: add apply_sink_bias flag (independent of n_comp)
For all-SWA segments (n_comp=0), sink bias still needs to be applied
to all positions. The apply_sink_bias flag controls compilation of
the sink bias code path, independent of n_comp offset.
2026-05-26 15:26:52 +00:00
60b6f8ee79 fix: use separate kernels for segments with/without compressed KV (n_comp is compile-time) 2026-05-26 15:23:21 +00:00
2efd15c852 fix: correct swa_len_local calculation per segment for D5c multi-tile 2026-05-26 15:22:03 +00:00
3abcc7ff09 D5c: multi-tile test using Python KV merge with sink bias 2026-05-26 15:20:45 +00:00
ffc4b542bc D5c: use single KV tile (s_k=128) to avoid broken O rescale
The D5c sink bias logic is VERIFIED CORRECT (cos 0.999996).
Multi-KV-tile fails due to known D1.5 O rescale bug (TMEM round-trip).
Using s_k=128 avoids the broken code path. Multi-tile support requires
the D1.5 correction epilog fix.
2026-05-26 15:11:50 +00:00
fc0f4bcf23 diag: test D5c with single KV tile (s_k=128) to isolate O rescale issue 2026-05-26 15:09:49 +00:00
e5381b7312 diag: add baseline test (s_k=256 D3 mask, no sink bias) to isolate D5c issue 2026-05-26 15:08:40 +00:00
016edbcc97 D5c: add row_sum output for proper external normalization
The kernel's O_unnorm is max-shifted (divided by 2^row_max), so
O_norm != O_unnorm * exp(-LSE). Instead, O_norm = O_unnorm / row_sum.
Added mRowSums output tensor to enable correct normalization.
2026-05-26 15:07:22 +00:00
31e6426049 fix: normalize kernel output using per-row LSE for D5c test 2026-05-26 15:04:47 +00:00
dbdbcecadc fix: sink_bias must be pre-converted to CuTe tensor before passing to compile 2026-05-26 15:02:43 +00:00
04b66e0f9c fix: test_d5c use float for attn_sink in reference functions 2026-05-26 15:01:31 +00:00
9d64434954 D5c: add sink bias (attn_sink) logit modification to FMHA kernel
- Add n_comp parameter: compressed KV length, sink bias applies to positions >= n_comp
- Add sink_bias parameter: per-head FP32 logit bias for SWA positions
- D3 mask updated: kv_pos >= n_comp + swa_len (backward compatible when n_comp=0)
- D4 causal mask updated: compare SWA-relative position (kv_pos - n_comp) with m_coord
- Mathematical insight: sink merge = single softmax over [S_comp, S_swa + attn_sink]
- Add test_d5c_fused.py with combined KV + sink bias test
2026-05-26 14:59:52 +00:00
865eed0d33 cleanup: remove debug test file 2026-05-26 11:03:45 +00:00
5b55cf0767 fix: k_seg is already 3D from slicing, don't add extra unsqueeze(-1) 2026-05-26 11:02:44 +00:00
375a682206 debug: isolated KV merge test 2026-05-26 11:01:22 +00:00
2252d7c865 fix: make K/V segments contiguous before passing to kernel (TMA needs contiguous tensors) 2026-05-26 11:00:36 +00:00
5407dc768a test: minor comment fix in D5b test 2026-05-26 10:59:51 +00:00
bb7ec341cb fix: D5b test uses reference attn_sum for normalization, correct D5 merge formula
- exp(LSE) != row_sum (it's row_sum * exp(max(S*scale)))
- Normalize using reference attn_sum (same as other tests)
- D5 merge uses normalized O + LSE: O = sum(exp(lse)*O_norm)/sum(exp(lse))
- Added 4-tile KV merge test (s_k=512)
2026-05-26 10:59:04 +00:00
6c73069cb9 D5b: Per-row LSE output + Python KV merge test
- Fix LSE output: all 128 rows now write (mLSE[sfw_idx, 0, 0])
  instead of only row 0 (mLSE[0])
- Each softmax thread (sfw_idx 0..127) independently writes its LSE
- This enables accurate Python-side KV merge for multi-KV-tile
- New test: test_d5b_perrow_lse.py with LSE verification + KV merge
2026-05-26 10:57:54 +00:00
24993428a2 fix: D4 test reference computation only applies causal mask when is_causal=True 2026-05-26 10:56:04 +00:00
e3e01071f4 fix: swa_len as Int32 scalar instead of CuTe tensor
CuTeDSL @cute.kernel cannot handle dynamic-shape tensors as parameters.
Pass swa_len as Int32 scalar instead of a 1D tensor.
This works for batch_size=1 (current config).
Updated D3 and D4 tests to pass swa_len as int.
2026-05-26 10:54:41 +00:00
841a3e87b2 D4: Causal mask on SWA branch
- Add is_causal flag to FmhaKernel constructor
- Mask positions where k_coord > m_coord to -inf (causal attention)
- Combined with D3 SWA mask: both conditions use OR logic
- Same tTMEM_LOADcS coordinate mapping as D3
- const_expr guarded: zero overhead when is_causal=False
- New test: test_d4_causal_mask.py with causal + combined masking
2026-05-26 10:52:30 +00:00
b6b581777a D3: In-kernel SWA sequence length masking
- Add apply_swa_mask flag to FmhaKernel constructor
- After TMEM load of S, use tTMEM_LOADcS coordinates to map register
  fragment positions to (row, col) in QK matrix
- Mask positions >= swa_lens[batch_idx] to -inf before softmax
- Supports multi-KV-tile (kt*128 + k_coord for absolute position)
- swa_lens parameter passed as CuTe tensor, indexed by block_idx_z
- Dummy tensor (max int) when swa_lens=None (no masking)
- New test: test_d3_inkernel_mask.py with proper in-kernel masking
- Replaces pre-masking approach (BF16 min on K) which can't produce -inf
2026-05-26 10:51:23 +00:00
e9f476b6dc fix typo: from_dlset → from_dlpack 2026-05-25 17:28:43 +00:00
f278348f44 D3: SWA mask with BF16 min pre-masking approach (K[invalid]=BF16_MIN → scores≈-inf) 2026-05-25 17:27:35 +00:00
cfbeb9c454 D3: SWA mask test with zero-masking approach (pre-mask K/V in Python) 2026-05-25 17:23:03 +00:00
68cb0236b5 D3: add SWA sequence length mask test (reference oracle + full-window regression) 2026-05-25 17:20:53 +00:00