Ld32x32bOp and St32x32bOp have different column mappings at the hardware
level. No layout transformation can fix this — the atoms themselves map
TMEM columns differently.
The MoE correction epilogue avoids the problem by doing a ONE-WAY trip
(TMEM→REGS→SMEM→GMEM, never writes back to TMEM). FMHA needs O in TMEM
for PV accumulation between kt iterations, so one-way doesn't help.
Production path for multi-KV-tile: Python KV merge (already verified,
cos 0.999998 for s_k up to 1024). Run kernel per 128-token segment.
Future: restructure PV to accumulate into REGS/SMEM instead of TMEM,
enabling the one-way correction epilogue pattern.
Keep epilogue_tma_store for final output (proven path).
Only fix the multi-KV-tile O rescale using paired atoms from
epilogue_tmem_copy_and_partition. The paired atoms share addressing,
making the TMEM->REGS->modify->TMEM cycle lossless.
Guarded by const_expr(n_kv_tiles > 1) so single-tile path (n=128)
is completely unaffected — zero regression risk.
Full correction epilogue (one-way TMEM->REGS->SMEM->GMEM) deferred
until we can address the MLIR compilation time issue.
Replace hand-constructed Ld32x32bOp/St32x32bOp TMEM round-trip with the
proven correction epilogue pattern from fused_swiglu.py:
1. O rescale (kt>0): TMEM→REGS (paired load), multiply by acc_scale,
REGS→TMEM (paired store via retile_to_S). No layout mismatch.
2. Final O output: One-way TMEM→REGS→SMEM→GMEM using
epilogue_tmem_copy_and_partition + epilogue_smem_copy_and_partition
+ TMA partition. Register-level normalization (divide by row_sum)
or raw BF16 cast for D5a path.
This fixes both D1.5 issues:
- Issue 1: TMEM round-trip corruption (hand-constructed atoms)
- Issue 2: O rescale for multi-KV-tile (kt>0)
Supports normalize=True (in-kernel) and normalize=False (D5a external).
Uses epilog_sync_bar + c_pipe for SMEM→GMEM, replacing epilogue_tma_store.
Both segments (compressed+SWA with n_comp=96, and SWA-only with n_comp=0)
pass individually at cos 0.999996. The Python KV merge produces the
correct combined attention at cos 0.999996.
Key: n_comp is compile-time, so separate kernel instances are needed
for segments with different n_comp values. Production code would use
a kernel cache keyed on (n_comp, apply_sink_bias, ...).
For all-SWA segments (n_comp=0), sink bias still needs to be applied
to all positions. The apply_sink_bias flag controls compilation of
the sink bias code path, independent of n_comp offset.
The D5c sink bias logic is VERIFIED CORRECT (cos 0.999996).
Multi-KV-tile fails due to known D1.5 O rescale bug (TMEM round-trip).
Using s_k=128 avoids the broken code path. Multi-tile support requires
the D1.5 correction epilog fix.
The softmax scales by scale_log2 = scale * log2(e). Adding sink_val to
raw logits causes it to be scaled too. Fix: add sink_val/scale instead,
so after scaling: (sink_val/scale) * scale_log2 = sink_val * log2(e).
This correctly multiplies attention weights by exp(sink_val).
- Fix LSE output: all 128 rows now write (mLSE[sfw_idx, 0, 0])
instead of only row 0 (mLSE[0])
- Each softmax thread (sfw_idx 0..127) independently writes its LSE
- This enables accurate Python-side KV merge for multi-KV-tile
- New test: test_d5b_perrow_lse.py with LSE verification + KV merge
CuTeDSL @cute.kernel cannot handle dynamic-shape tensors as parameters.
Pass swa_len as Int32 scalar instead of a 1D tensor.
This works for batch_size=1 (current config).
Updated D3 and D4 tests to pass swa_len as int.
- Add is_causal flag to FmhaKernel constructor
- Mask positions where k_coord > m_coord to -inf (causal attention)
- Combined with D3 SWA mask: both conditions use OR logic
- Same tTMEM_LOADcS coordinate mapping as D3
- const_expr guarded: zero overhead when is_causal=False
- New test: test_d4_causal_mask.py with causal + combined masking
- Add apply_swa_mask flag to FmhaKernel constructor
- After TMEM load of S, use tTMEM_LOADcS coordinates to map register
fragment positions to (row, col) in QK matrix
- Mask positions >= swa_lens[batch_idx] to -inf before softmax
- Supports multi-KV-tile (kt*128 + k_coord for absolute position)
- swa_lens parameter passed as CuTe tensor, indexed by block_idx_z
- Dummy tensor (max int) when swa_lens=None (no masking)
- New test: test_d3_inkernel_mask.py with proper in-kernel masking
- Replaces pre-masking approach (BF16 min on K) which can't produce -inf