Added epilogue_tmem_copy_and_partition + epilogue_smem_copy_and_partition
setup for multi-KV-tile O rescale. The one-way TMEM→REGS→SMEM pipeline
is wired up, but the SMEM-level accumulation (load-previous, scale, add,
store-back) needs implementation. Currently falls through to Python KV merge.
Previous attempts used tOtO0 (from pv_thr.make_fragment_C) and corrupted data.
This version uses tCtO_base (from pv_mma.make_fragment_C) which is the SAME
tensor the epilogue successfully reads O from. Both load and store atoms built
from same tCtO_i via composition — CUTLASS correction_rescale pattern.
TMEM round-trip via Ld32x32bOp/St32x32bOp corrupts O accumulator data
even with CUTLASS correction_rescale pattern. All variants tested:
- Repetition(16) + composition (CUTLASS exact pattern) — BROKEN
- Repetition(32) + composition — BROKEN
- Repetition(16) raw layout (no composition) — BROKEN
Even NO-OP (multiply by 1.0) produces catastrophically wrong results.
Production path remains Python KV merge (cos 0.999998 for s_k up to 1024).
Next: SMEM accumulator approach (one-way TMEM→REGS→SMEM per kt).
The MMA warp needs fence_view_async_tmem_load() before PV[kt>0] to ensure
the rescaled O values are visible. NamedBarrier synchronizes warps but may
not guarantee TMEM visibility without an explicit fence.
- Both load and store atoms built from SAME tOtO_i (composition-tiled)
- Same Repetition(corr_tile_size=16) for both copies
- pv_done_bar synchronization between MMA and softmax warps
- acc_scale computed per kt iteration, used to rescale O in TMEM
- const_expr(n_kv_tiles > 1) guards for zero overhead at s_k=128
- New test: test_d15_in_kernel_rescale.py (s_k=128/256/384)
- Minimal roundtrip test: test_tmem_roundtrip_minimal.py
Ld32x32bOp and St32x32bOp have different column mappings at the hardware
level. No layout transformation can fix this — the atoms themselves map
TMEM columns differently.
The MoE correction epilogue avoids the problem by doing a ONE-WAY trip
(TMEM→REGS→SMEM→GMEM, never writes back to TMEM). FMHA needs O in TMEM
for PV accumulation between kt iterations, so one-way doesn't help.
Production path for multi-KV-tile: Python KV merge (already verified,
cos 0.999998 for s_k up to 1024). Run kernel per 128-token segment.
Future: restructure PV to accumulate into REGS/SMEM instead of TMEM,
enabling the one-way correction epilogue pattern.
Keep epilogue_tma_store for final output (proven path).
Only fix the multi-KV-tile O rescale using paired atoms from
epilogue_tmem_copy_and_partition. The paired atoms share addressing,
making the TMEM->REGS->modify->TMEM cycle lossless.
Guarded by const_expr(n_kv_tiles > 1) so single-tile path (n=128)
is completely unaffected — zero regression risk.
Full correction epilogue (one-way TMEM->REGS->SMEM->GMEM) deferred
until we can address the MLIR compilation time issue.
Replace hand-constructed Ld32x32bOp/St32x32bOp TMEM round-trip with the
proven correction epilogue pattern from fused_swiglu.py:
1. O rescale (kt>0): TMEM→REGS (paired load), multiply by acc_scale,
REGS→TMEM (paired store via retile_to_S). No layout mismatch.
2. Final O output: One-way TMEM→REGS→SMEM→GMEM using
epilogue_tmem_copy_and_partition + epilogue_smem_copy_and_partition
+ TMA partition. Register-level normalization (divide by row_sum)
or raw BF16 cast for D5a path.
This fixes both D1.5 issues:
- Issue 1: TMEM round-trip corruption (hand-constructed atoms)
- Issue 2: O rescale for multi-KV-tile (kt>0)
Supports normalize=True (in-kernel) and normalize=False (D5a external).
Uses epilog_sync_bar + c_pipe for SMEM→GMEM, replacing epilogue_tma_store.
Both segments (compressed+SWA with n_comp=96, and SWA-only with n_comp=0)
pass individually at cos 0.999996. The Python KV merge produces the
correct combined attention at cos 0.999996.
Key: n_comp is compile-time, so separate kernel instances are needed
for segments with different n_comp values. Production code would use
a kernel cache keyed on (n_comp, apply_sink_bias, ...).