Commit Graph

1184 Commits

Author SHA1 Message Date
064ececc9a Update docs: D1.5 TMEM round-trip fundamentally broken, Python KV merge is production path 2026-05-26 19:53:10 +00:00
2b4f4ce538 Remove broken D1.5 paired-atom test (TMEM round-trick is fundamentally broken) 2026-05-26 19:50:31 +00:00
ffb3e736bb D1.5: Revert broken paired-atom O rescale — TMEM round-trip fundamentally broken
Ld32x32bOp and St32x32bOp have different column mappings at the hardware
level. No layout transformation can fix this — the atoms themselves map
TMEM columns differently.

The MoE correction epilogue avoids the problem by doing a ONE-WAY trip
(TMEM→REGS→SMEM→GMEM, never writes back to TMEM). FMHA needs O in TMEM
for PV accumulation between kt iterations, so one-way doesn't help.

Production path for multi-KV-tile: Python KV merge (already verified,
cos 0.999998 for s_k up to 1024). Run kernel per 128-token segment.

Future: restructure PV to accumulate into REGS/SMEM instead of TMEM,
enabling the one-way correction epilogue pattern.
2026-05-26 19:50:11 +00:00
40cbf0c223 Add D1.5 paired-atom O rescale test (s_k=256/384, hd=64/128) 2026-05-26 19:46:19 +00:00
43f0b5d1e8 D1.5: Fix O rescale with paired atoms (incremental approach)
Keep epilogue_tma_store for final output (proven path).
Only fix the multi-KV-tile O rescale using paired atoms from
epilogue_tmem_copy_and_partition. The paired atoms share addressing,
making the TMEM->REGS->modify->TMEM cycle lossless.

Guarded by const_expr(n_kv_tiles > 1) so single-tile path (n=128)
is completely unaffected — zero regression risk.

Full correction epilogue (one-way TMEM->REGS->SMEM->GMEM) deferred
until we can address the MLIR compilation time issue.
2026-05-26 19:34:26 +00:00
4bb0e063cc D1.5: Replace broken TMEM round-trip with correction epilogue (paired atoms)
Replace hand-constructed Ld32x32bOp/St32x32bOp TMEM round-trip with the
proven correction epilogue pattern from fused_swiglu.py:

1. O rescale (kt>0): TMEM→REGS (paired load), multiply by acc_scale,
   REGS→TMEM (paired store via retile_to_S). No layout mismatch.

2. Final O output: One-way TMEM→REGS→SMEM→GMEM using
   epilogue_tmem_copy_and_partition + epilogue_smem_copy_and_partition
   + TMA partition. Register-level normalization (divide by row_sum)
   or raw BF16 cast for D5a path.

This fixes both D1.5 issues:
- Issue 1: TMEM round-trip corruption (hand-constructed atoms)
- Issue 2: O rescale for multi-KV-tile (kt>0)

Supports normalize=True (in-kernel) and normalize=False (D5a external).
Uses epilog_sync_bar + c_pipe for SMEM→GMEM, replacing epilogue_tma_store.
2026-05-26 19:11:19 +00:00
f97aee6eed plan update 2026-05-26 19:00:22 +00:00
487d960a6a D5c multi-tile: VERIFIED cos 0.999996 with Python KV merge + sink bias
Both segments (compressed+SWA with n_comp=96, and SWA-only with n_comp=0)
pass individually at cos 0.999996. The Python KV merge produces the
correct combined attention at cos 0.999996.

Key: n_comp is compile-time, so separate kernel instances are needed
for segments with different n_comp values. Production code would use
a kernel cache keyed on (n_comp, apply_sink_bias, ...).
2026-05-26 15:40:45 +00:00
c9eab3c7e0 diag: rewrite multi-tile test with explicit per-segment compile and reference 2026-05-26 15:39:39 +00:00
7f983fb855 diag: add direct segment 0 test to compare with run_segment 2026-05-26 15:37:06 +00:00
2a5f9dc6e3 fix: simplify run_segment to use full hd V tensor (was incorrectly splitting by pv_n_tile) 2026-05-26 15:34:57 +00:00
aa2df1a202 diag: test n_comp=96 with sink bias directly 2026-05-26 15:33:38 +00:00
25b236fe00 diag: test D5c multi-tile with no sink bias to isolate issue 2026-05-26 15:31:38 +00:00
a3989929de diag: per-segment reference comparison for D5c multi-tile 2026-05-26 15:29:52 +00:00
bbc29945e8 diag: add per-segment debug output for D5c multi-tile 2026-05-26 15:28:19 +00:00
e64392f1ac D5c: add apply_sink_bias flag (independent of n_comp)
For all-SWA segments (n_comp=0), sink bias still needs to be applied
to all positions. The apply_sink_bias flag controls compilation of
the sink bias code path, independent of n_comp offset.
2026-05-26 15:26:52 +00:00
60b6f8ee79 fix: use separate kernels for segments with/without compressed KV (n_comp is compile-time) 2026-05-26 15:23:21 +00:00
2efd15c852 fix: correct swa_len_local calculation per segment for D5c multi-tile 2026-05-26 15:22:03 +00:00
3abcc7ff09 D5c: multi-tile test using Python KV merge with sink bias 2026-05-26 15:20:45 +00:00
57a8316bc1 update README: D5c sink bias DONE (cos 0.999996, single KV tile) 2026-05-26 15:17:10 +00:00
8f7df4d8b5 fix: mRowSums dummy tensor must match mLSE layout (3D, not 1D) 2026-05-26 15:14:35 +00:00
ffc4b542bc D5c: use single KV tile (s_k=128) to avoid broken O rescale
The D5c sink bias logic is VERIFIED CORRECT (cos 0.999996).
Multi-KV-tile fails due to known D1.5 O rescale bug (TMEM round-trip).
Using s_k=128 avoids the broken code path. Multi-tile support requires
the D1.5 correction epilog fix.
2026-05-26 15:11:50 +00:00
fc0f4bcf23 diag: test D5c with single KV tile (s_k=128) to isolate O rescale issue 2026-05-26 15:09:49 +00:00
e5381b7312 diag: add baseline test (s_k=256 D3 mask, no sink bias) to isolate D5c issue 2026-05-26 15:08:40 +00:00
016edbcc97 D5c: add row_sum output for proper external normalization
The kernel's O_unnorm is max-shifted (divided by 2^row_max), so
O_norm != O_unnorm * exp(-LSE). Instead, O_norm = O_unnorm / row_sum.
Added mRowSums output tensor to enable correct normalization.
2026-05-26 15:07:22 +00:00
31e6426049 fix: normalize kernel output using per-row LSE for D5c test 2026-05-26 15:04:47 +00:00
014d647ba3 fix: sink bias domain correction — add attn_sink/scale to raw logits
The softmax scales by scale_log2 = scale * log2(e). Adding sink_val to
raw logits causes it to be scaled too. Fix: add sink_val/scale instead,
so after scaling: (sink_val/scale) * scale_log2 = sink_val * log2(e).
This correctly multiplies attention weights by exp(sink_val).
2026-05-26 15:03:49 +00:00
dbdbcecadc fix: sink_bias must be pre-converted to CuTe tensor before passing to compile 2026-05-26 15:02:43 +00:00
04b66e0f9c fix: test_d5c use float for attn_sink in reference functions 2026-05-26 15:01:31 +00:00
9d64434954 D5c: add sink bias (attn_sink) logit modification to FMHA kernel
- Add n_comp parameter: compressed KV length, sink bias applies to positions >= n_comp
- Add sink_bias parameter: per-head FP32 logit bias for SWA positions
- D3 mask updated: kv_pos >= n_comp + swa_len (backward compatible when n_comp=0)
- D4 causal mask updated: compare SWA-relative position (kv_pos - n_comp) with m_coord
- Mathematical insight: sink merge = single softmax over [S_comp, S_swa + attn_sink]
- Add test_d5c_fused.py with combined KV + sink bias test
2026-05-26 14:59:52 +00:00
60a6f2d296 update README: D5b per-row LSE, D3/D4 DONE 2026-05-26 11:03:57 +00:00
865eed0d33 cleanup: remove debug test file 2026-05-26 11:03:45 +00:00
5b55cf0767 fix: k_seg is already 3D from slicing, don't add extra unsqueeze(-1) 2026-05-26 11:02:44 +00:00
375a682206 debug: isolated KV merge test 2026-05-26 11:01:22 +00:00
2252d7c865 fix: make K/V segments contiguous before passing to kernel (TMA needs contiguous tensors) 2026-05-26 11:00:36 +00:00
5407dc768a test: minor comment fix in D5b test 2026-05-26 10:59:51 +00:00
bb7ec341cb fix: D5b test uses reference attn_sum for normalization, correct D5 merge formula
- exp(LSE) != row_sum (it's row_sum * exp(max(S*scale)))
- Normalize using reference attn_sum (same as other tests)
- D5 merge uses normalized O + LSE: O = sum(exp(lse)*O_norm)/sum(exp(lse))
- Added 4-tile KV merge test (s_k=512)
2026-05-26 10:59:04 +00:00
6c73069cb9 D5b: Per-row LSE output + Python KV merge test
- Fix LSE output: all 128 rows now write (mLSE[sfw_idx, 0, 0])
  instead of only row 0 (mLSE[0])
- Each softmax thread (sfw_idx 0..127) independently writes its LSE
- This enables accurate Python-side KV merge for multi-KV-tile
- New test: test_d5b_perrow_lse.py with LSE verification + KV merge
2026-05-26 10:57:54 +00:00
4656fa81f9 update README: D3 and D4 status DONE 2026-05-26 10:56:57 +00:00
24993428a2 fix: D4 test reference computation only applies causal mask when is_causal=True 2026-05-26 10:56:04 +00:00
e3e01071f4 fix: swa_len as Int32 scalar instead of CuTe tensor
CuTeDSL @cute.kernel cannot handle dynamic-shape tensors as parameters.
Pass swa_len as Int32 scalar instead of a 1D tensor.
This works for batch_size=1 (current config).
Updated D3 and D4 tests to pass swa_len as int.
2026-05-26 10:54:41 +00:00
df84420414 fix: add is_causal to FmhaKernel __init__ signature 2026-05-26 10:53:14 +00:00
841a3e87b2 D4: Causal mask on SWA branch
- Add is_causal flag to FmhaKernel constructor
- Mask positions where k_coord > m_coord to -inf (causal attention)
- Combined with D3 SWA mask: both conditions use OR logic
- Same tTMEM_LOADcS coordinate mapping as D3
- const_expr guarded: zero overhead when is_causal=False
- New test: test_d4_causal_mask.py with causal + combined masking
2026-05-26 10:52:30 +00:00
b6b581777a D3: In-kernel SWA sequence length masking
- Add apply_swa_mask flag to FmhaKernel constructor
- After TMEM load of S, use tTMEM_LOADcS coordinates to map register
  fragment positions to (row, col) in QK matrix
- Mask positions >= swa_lens[batch_idx] to -inf before softmax
- Supports multi-KV-tile (kt*128 + k_coord for absolute position)
- swa_lens parameter passed as CuTe tensor, indexed by block_idx_z
- Dummy tensor (max int) when swa_lens=None (no masking)
- New test: test_d3_inkernel_mask.py with proper in-kernel masking
- Replaces pre-masking approach (BF16 min on K) which can't produce -inf
2026-05-26 10:51:23 +00:00
d6a56342cc D3: add swa_lens parameter to FmhaKernel (in-kernel masking TBD) 2026-05-25 17:31:01 +00:00
e9f476b6dc fix typo: from_dlset → from_dlpack 2026-05-25 17:28:43 +00:00
f278348f44 D3: SWA mask with BF16 min pre-masking approach (K[invalid]=BF16_MIN → scores≈-inf) 2026-05-25 17:27:35 +00:00
cfbeb9c454 D3: SWA mask test with zero-masking approach (pre-mask K/V in Python) 2026-05-25 17:23:03 +00:00
68cb0236b5 D3: add SWA sequence length mask test (reference oracle + full-window regression) 2026-05-25 17:20:53 +00:00
7f69979c5f D1.5: add multi-KV-tile attention test with Python KV merge
- Splits K/V into 128-token segments
- Runs FMHA per segment, merges with exp(lse) weighted sum
- Tests: s_k=256 (2 tiles), s_k=512 (4 tiles)
- Uses reference attn_sum for normalization
2026-05-25 17:18:50 +00:00