2 Commits

Author SHA1 Message Date
bf2c7c8bb8 D1.5: Implement in-kernel O rescale via CUTLASS correction_rescale pattern
- Both load and store atoms built from SAME tOtO_i (composition-tiled)
- Same Repetition(corr_tile_size=16) for both copies
- pv_done_bar synchronization between MMA and softmax warps
- acc_scale computed per kt iteration, used to rescale O in TMEM
- const_expr(n_kv_tiles > 1) guards for zero overhead at s_k=128
- New test: test_d15_in_kernel_rescale.py (s_k=128/256/384)
- Minimal roundtrip test: test_tmem_roundtrip_minimal.py
2026-05-26 20:26:06 +00:00
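
The per-iteration rescale described in bf2c7c8bb8 (acc_scale computed per kt iteration and applied to O) looks like the usual online-softmax correction. Below is a minimal pure-Python sketch of that arithmetic for a single query row, assuming acc_scale = exp(old_max - new_max). The function names, the list-based data layout, and the tile loop are illustrative assumptions, not code from the kernel, which performs the rescale on O held in TMEM.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def attention_reference(q, keys, values):
    """Single-pass softmax(q . K^T) @ V for one query row."""
    scores = [dot(q, k) for k in keys]
    row_max = max(scores)
    weights = [math.exp(s - row_max) for s in scores]
    denom = sum(weights)
    dim = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) / denom
            for d in range(dim)]


def attention_tiled(q, keys, values, tile):
    """Same result, processing K/V in tiles and rescaling O per tile.

    Hypothetical stand-in for the kt loop: when the running row max
    grows, the partial output O and the running sum are multiplied by
    acc_scale so earlier tiles stay consistent with the new max.
    """
    dim = len(values[0])
    row_max = -math.inf
    row_sum = 0.0
    o = [0.0] * dim
    for start in range(0, len(keys), tile):
        k_tile = keys[start:start + tile]
        v_tile = values[start:start + tile]
        scores = [dot(q, k) for k in k_tile]
        new_max = max(row_max, max(scores))
        # exp(-inf) == 0.0 on the first tile, where O is still zero.
        acc_scale = math.exp(row_max - new_max)
        o = [x * acc_scale for x in o]
        row_sum *= acc_scale
        probs = [math.exp(s - new_max) for s in scores]
        row_sum += sum(probs)
        for p, v in zip(probs, v_tile):
            for d in range(dim):
                o[d] += p * v[d]
        row_max = new_max
    return [x / row_sum for x in o]
```

With a single tile the rescale multiplies an all-zero O and changes nothing, which is consistent with the commit guarding that path behind const_expr(n_kv_tiles > 1).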
43f0b5d1e8 D1.5: Fix O rescale with paired atoms (incremental approach)
Keep epilogue_tma_store for final output (proven path).
Only fix the multi-KV-tile O rescale using paired atoms from
epilogue_tmem_copy_and_partition. The paired atoms share addressing,
making the TMEM->REGS->modify->TMEM cycle lossless.

Guarded by const_expr(n_kv_tiles > 1), so the single-tile path (n=128)
is completely unaffected and carries zero regression risk.

Full correction epilogue (one-way TMEM->REGS->SMEM->GMEM) deferred
until we can address the MLIR compilation time issue.
2026-05-26 19:34:26 +00:00
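
The claim in 43f0b5d1e8 that paired atoms make the TMEM->REGS->modify->TMEM cycle lossless can be shown with a toy model. The sketch below treats TMEM as a flat list and a copy atom as a mapping from register slot to TMEM address. All names here are hypothetical and none of this is the CUTLASS API; it only illustrates why the load and the store need to share addressing.

```python
def tmem_load(tmem, layout):
    """Read TMEM into registers; layout[i] is the address for slot i."""
    return [tmem[addr] for addr in layout]


def tmem_store(tmem, regs, layout):
    """Write registers back to TMEM using layout[i] as the address."""
    for addr, value in zip(layout, regs):
        tmem[addr] = value


def rescale_roundtrip(tmem, load_layout, store_layout, scale):
    """Load, scale in registers, store; returns a new TMEM image."""
    regs = tmem_load(tmem, load_layout)
    regs = [r * scale for r in regs]
    out = list(tmem)
    tmem_store(out, regs, store_layout)
    return out


tmem = [1.0, 2.0, 3.0, 4.0]
paired = [2, 0, 3, 1]     # one addressing shared by load and store
unrelated = [0, 1, 2, 3]  # a store layout built independently

# Shared addressing: every element is scaled in place.
lossless = rescale_roundtrip(tmem, paired, paired, 0.5)

# Mismatched addressing: scaled values land at the wrong addresses.
scrambled = rescale_roundtrip(tmem, paired, unrelated, 0.5)
```

In this model the shared-layout round trip returns each element scaled at its original address, while the mismatched one permutes the data.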