Commit Graph

1224 Commits

Author SHA1 Message Date
b39d7f1a14 Try cute.copy(tma_c, sC_flat, gO) directly 2026-05-27 05:29:51 +00:00
2af767a90c Try full tensor TMA copy without slicing 2026-05-27 05:28:43 +00:00
7d14a2f764 sC_flat with simple (128, pv_n_tile) layout for full epi_tile coverage 2026-05-27 05:27:51 +00:00
6fb0e6a417 Use sC_flat (non-swizzled epi_s layout) for TMA store from SMEM accumulator 2026-05-27 05:26:50 +00:00
4a2a06f9e1 Fix gO slice: use separate Int32(0) instead of tuple 2026-05-27 05:25:33 +00:00
bf36979a8d Use CUTLASS FMHA reference pattern for sC->GMEM TMA store (flat_divide + tma_partition) 2026-05-27 05:24:39 +00:00
97bc6d8d2f Add c_direct GMEM tensor for direct writes in SMEM accumulator path 2026-05-27 05:15:47 +00:00
3d349b497b SME accumulator: direct GMEM write from sO_acc (bypass TMA for multi-kt) 2026-05-27 05:14:31 +00:00
7d1e0a605d Different coordinate dims for bSG_sC (2D) and bSG_gC (3D) 2026-05-27 05:13:38 +00:00
75b272c5f2 2D coordinate for bSG_sC TMA copy 2026-05-27 05:12:58 +00:00
72dff90165 3D coordinate for bSG_sC/gC TMA copy 2026-05-27 05:12:11 +00:00
b8b6e8cc0b Slice bSG_gC MMA tile coords for TMA copy 2026-05-27 05:11:26 +00:00
754740d5e5 Try bSG_sC[(None, 0)] for TMA copy coordinate 2026-05-27 05:10:40 +00:00
23a2b49daf Add SMEM accumulator for n_kv_tiles>1: O load from TMEM, accumulate in sO_acc, TMA store from sC 2026-05-27 05:09:54 +00:00
a858ed1c14 Fix test: normalize=False for un-normalized O comparison 2026-05-27 05:06:52 +00:00
2e262d2b99 Reset fmha_smem_acc.py to working fmha.py base 2026-05-27 05:05:41 +00:00
b43ffe9dac Guard sO_acc allocation/zero-init with n_kv_tiles>1 2026-05-27 05:05:01 +00:00
101840c78c Guard SMEM accumulation with n_kv_tiles>1 to avoid TMEM destructive read 2026-05-27 05:02:51 +00:00
02a34512cb Use epilogue_tma_store for n_kv_tiles=1; TODO for multi-tile 2026-05-27 05:01:39 +00:00
4652cab8b4 Fix: 3D coords for TMA copy (bSG_sC has 3 modes) 2026-05-27 05:00:39 +00:00
b0ebf41ee3 Slice bSG_gC with mma_tile_coord (like epilogue_tma_store) 2026-05-27 05:00:04 +00:00
eb0bf0cce0 Fix TMA store: use bSG_sC[(None,0)] indexing pattern from epilogue_tma_store 2026-05-27 04:59:29 +00:00
7ea77a121f Use cpasync.tma_partition for SMEM->GMEM TMA store (like epilogue_tma_store) 2026-05-27 04:58:47 +00:00
e614d0894c Clean up SMEM acc epilogue: flat indexing sO_acc->sC, TMA store from sC_s0 2026-05-27 04:57:40 +00:00
1724eeb8ec Fix TMA store: use epi_s view of sC for proper layout compatibility 2026-05-27 04:55:18 +00:00
3a7d87adba Fix test_smem_acc: use keyword args for lse/row_sums 2026-05-27 04:54:23 +00:00
6a621bdf64 D1.5: SMEM accumulator FMHA kernel — one-way TMEM→REGS→SMEM, no round-trip
TMEM round-trip (Ld32x32bOp/St32x32bOp) is FUNDAMENTALLY BROKEN.
Even NO-OP (multiply by 1.0) corrupts data.

New approach:
- PV always ACCUMULATE=False (fresh TMEM each kt)
- After pv_done_bar: one-way Ld32x32bOp load O_kt from TMEM→REGS
- Coordinate-indexed SMEM accumulation: sO_acc = acc_scale * sO_acc + O_kt
- sO_acc: FP32 [128, pv_n_tile] row-major (32KB at hd=64, 64KB at hd=128)
- Final: normalize, cast BF16, write to sC, TMA store to GMEM
2026-05-27 04:53:40 +00:00
81acf1593c Revert "D1.5: WIP SMEM accumulator — framework in place, accumulation logic TODO"
This reverts commit 72d88af400.
2026-05-27 02:17:26 +00:00
72d88af400 D1.5: WIP SMEM accumulator — framework in place, accumulation logic TODO
Added epilogue_tmem_copy_and_partition + epilogue_smem_copy_and_partition
setup for multi-KV-tile O rescale. The one-way TMEM→REGS→SMEM pipeline
is wired up, but the SMEM-level accumulation (load-previous, scale, add,
store-back) needs implementation. Currently falls through to Python KV merge.
2026-05-27 02:15:23 +00:00
a6da93ddfb Revert "D1.5: Try O rescale with tCtO_base layout (epilogue-proven TMEM addressing)"
This reverts commit 79e2eb3b42.
2026-05-27 02:12:20 +00:00
79e2eb3b42 D1.5: Try O rescale with tCtO_base layout (epilogue-proven TMEM addressing)
Previous attempts used tOtO0 (from pv_thr.make_fragment_C) and corrupted data.
This version uses tCtO_base (from pv_mma.make_fragment_C) which is the SAME
tensor the epilogue successfully reads O from. Both load and store atoms built
from same tCtO_i via composition — CUTLASS correction_rescale pattern.
2026-05-27 02:10:39 +00:00
f94978ffa7 D1.5: Prepare for SMEM accumulator implementation
- Added epilogue utility imports (transform_partitioned_tensor_layout, etc.)
- Re-added pv_done_bar for SMEM accumulator synchronization
- Backed up current fmha.py as fmha_backup_v2.py
- SMEM accumulator approach: one-way TMEM→REGS→SMEM per kt, accumulate in FP32 SMEM
2026-05-26 21:00:41 +00:00
afb93eae22 D1.5: Revert broken TMEM round-trip O rescale, document as fundamentally broken
TMEM round-trip via Ld32x32bOp/St32x32bOp corrupts O accumulator data
even with CUTLASS correction_rescale pattern. All variants tested:
- Repetition(16) + composition (CUTLASS exact pattern) — BROKEN
- Repetition(32) + composition — BROKEN
- Repetition(16) raw layout (no composition) — BROKEN
Even NO-OP (multiply by 1.0) produces catastrophically wrong results.

Production path remains Python KV merge (cos 0.999998 for s_k up to 1024).
Next: SMEM accumulator approach (one-way TMEM→REGS→SMEM per kt).
2026-05-26 20:55:16 +00:00
42c5793add D1.5: Add isolated round-trip test comparing s_k=128 vs s_k=256 with NOOP rescale 2026-05-26 20:45:58 +00:00
e35b30dae6 D1.5 debug: try corr_tile_size=32 for O rescale round-trip 2026-05-26 20:43:29 +00:00
20ed6d5114 D1.5: Add TMEM load fence before PV with ACCUMULATE, revert debug rescale factor
The MMA warp needs fence_view_async_tmem_load() before PV[kt>0] to ensure
the rescaled O values are visible. NamedBarrier synchronizes warps but may
not guarantee TMEM visibility without an explicit fence.
2026-05-26 20:31:28 +00:00
34d64137ec D1.5 debug: force rescale_factor=0.5 to test if round-trip code executes 2026-05-26 20:29:34 +00:00
3be708d923 D1.5 debug: add NOOP rescale test (acc_scale=1.0) to isolate TMEM round-trip corruption 2026-05-26 20:28:55 +00:00
c3648e4ebf D1.5 debug: add targeted s_k=256 rescale diagnostic test 2026-05-26 20:27:37 +00:00
bf2c7c8bb8 D1.5: Implement in-kernel O rescale via CUTLASS correction_rescale pattern
- Both load and store atoms built from SAME tOtO_i (composition-tiled)
- Same Repetition(corr_tile_size=16) for both copies
- pv_done_bar synchronization between MMA and softmax warps
- acc_scale computed per kt iteration, used to rescale O in TMEM
- const_expr(n_kv_tiles > 1) guards for zero overhead at s_k=128
- New test: test_d15_in_kernel_rescale.py (s_k=128/256/384)
- Minimal roundtrip test: test_tmem_roundtrip_minimal.py
2026-05-26 20:26:06 +00:00
064ececc9a Update docs: D1.5 TMEM round-trip fundamentally broken, Python KV merge is production path 2026-05-26 19:53:10 +00:00
2b4f4ce538 Remove broken D1.5 paired-atom test (TMEM round-trick is fundamentally broken) 2026-05-26 19:50:31 +00:00
ffb3e736bb D1.5: Revert broken paired-atom O rescale — TMEM round-trip fundamentally broken
Ld32x32bOp and St32x32bOp have different column mappings at the hardware
level. No layout transformation can fix this — the atoms themselves map
TMEM columns differently.

The MoE correction epilogue avoids the problem by doing a ONE-WAY trip
(TMEM→REGS→SMEM→GMEM, never writes back to TMEM). FMHA needs O in TMEM
for PV accumulation between kt iterations, so one-way doesn't help.

Production path for multi-KV-tile: Python KV merge (already verified,
cos 0.999998 for s_k up to 1024). Run kernel per 128-token segment.

Future: restructure PV to accumulate into REGS/SMEM instead of TMEM,
enabling the one-way correction epilogue pattern.
2026-05-26 19:50:11 +00:00
40cbf0c223 Add D1.5 paired-atom O rescale test (s_k=256/384, hd=64/128) 2026-05-26 19:46:19 +00:00
43f0b5d1e8 D1.5: Fix O rescale with paired atoms (incremental approach)
Keep epilogue_tma_store for final output (proven path).
Only fix the multi-KV-tile O rescale using paired atoms from
epilogue_tmem_copy_and_partition. The paired atoms share addressing,
making the TMEM->REGS->modify->TMEM cycle lossless.

Guarded by const_expr(n_kv_tiles > 1) so single-tile path (n=128)
is completely unaffected — zero regression risk.

Full correction epilogue (one-way TMEM->REGS->SMEM->GMEM) deferred
until we can address the MLIR compilation time issue.
2026-05-26 19:34:26 +00:00
4bb0e063cc D1.5: Replace broken TMEM round-trip with correction epilogue (paired atoms)
Replace hand-constructed Ld32x32bOp/St32x32bOp TMEM round-trip with the
proven correction epilogue pattern from fused_swiglu.py:

1. O rescale (kt>0): TMEM→REGS (paired load), multiply by acc_scale,
   REGS→TMEM (paired store via retile_to_S). No layout mismatch.

2. Final O output: One-way TMEM→REGS→SMEM→GMEM using
   epilogue_tmem_copy_and_partition + epilogue_smem_copy_and_partition
   + TMA partition. Register-level normalization (divide by row_sum)
   or raw BF16 cast for D5a path.

This fixes both D1.5 issues:
- Issue 1: TMEM round-trip corruption (hand-constructed atoms)
- Issue 2: O rescale for multi-KV-tile (kt>0)

Supports normalize=True (in-kernel) and normalize=False (D5a external).
Uses epilog_sync_bar + c_pipe for SMEM→GMEM, replacing epilogue_tma_store.
2026-05-26 19:11:19 +00:00
f97aee6eed plan update 2026-05-26 19:00:22 +00:00
487d960a6a D5c multi-tile: VERIFIED cos 0.999996 with Python KV merge + sink bias
Both segments (compressed+SWA with n_comp=96, and SWA-only with n_comp=0)
pass individually at cos 0.999996. The Python KV merge produces the
correct combined attention at cos 0.999996.

Key: n_comp is compile-time, so separate kernel instances are needed
for segments with different n_comp values. Production code would use
a kernel cache keyed on (n_comp, apply_sink_bias, ...).
2026-05-26 15:40:45 +00:00
c9eab3c7e0 diag: rewrite multi-tile test with explicit per-segment compile and reference 2026-05-26 15:39:39 +00:00
7f983fb855 diag: add direct segment 0 test to compare with run_segment 2026-05-26 15:37:06 +00:00