3a25c7feff
Test multi-KV merge (2 segments) separately from multi-head
2026-05-27 06:54:16 +00:00
36a6f07a7e
Fix: unsqueeze k/v when dim==2
2026-05-27 06:52:43 +00:00
fc4172937c
Clean production wrapper: always normalize=False + KV merge
2026-05-27 06:51:14 +00:00
8f87109f86
Single-segment: use normalize=False + per-row normalization from row_sums
2026-05-27 06:48:56 +00:00
fe55bf23a0
Split single-segment (normalized) and multi-segment (KV merge) paths
2026-05-27 06:46:30 +00:00
e45b94c01b
Test: compare both normalized and un-normalized reference
2026-05-27 06:44:37 +00:00
b70ab2a6ee
Return o_accum directly (un-normalized merge result)
2026-05-27 06:42:58 +00:00
6111db571c
Match working test: don't pass row_sums to kernel
2026-05-27 06:41:44 +00:00
312ac52d15
Normalize O_accum by exp(lse) before returning
2026-05-27 06:39:36 +00:00
ddc701af9b
Use exact merge formula from working test_d1_kv_merge.py
2026-05-27 06:38:04 +00:00
8321ccf9c1
Fix production KV merge: use normalized O for log-sum-exp merge
2026-05-27 06:36:24 +00:00
98c93c1cd8
Stage E: production attention wrapper + Python KV merge, clean fmha_smem_acc
2026-05-27 06:34:10 +00:00
51e456df44
Slice MMA tile coords from tOgO for TMA copy
2026-05-27 05:39:42 +00:00
1caa737b09
Move sC_flat_staged creation before const_expr guard
2026-05-27 05:38:39 +00:00
3c9dbc0c5d
Staged sC_flat with (128, pv_n_tile//2, 2) to match TMA atom
2026-05-27 05:37:05 +00:00
de2028b106
Split sC_flat into staged layout to match TMA atom decomposition
2026-05-27 05:35:56 +00:00
a0e9f7534b
Use tCgC_epi (transformed) for GMEM side of TMA partition
2026-05-27 05:34:40 +00:00
b02e103ac0
Add c_simple GMEM tensor (non-dynamic) for SMEM accumulator TMA store
2026-05-27 05:33:30 +00:00
2438826eee
Use tma_partition with group_modes on both sC_flat and gO
2026-05-27 05:31:47 +00:00
603f52de78
Fix gO creation: use slice_(pv_mma_tiler) like fmha.py
2026-05-27 05:30:50 +00:00
b39d7f1a14
Try cute.copy(tma_c, sC_flat, gO) directly
2026-05-27 05:29:51 +00:00
2af767a90c
Try full tensor TMA copy without slicing
2026-05-27 05:28:43 +00:00
7d14a2f764
sC_flat with simple (128, pv_n_tile) layout for full epi_tile coverage
2026-05-27 05:27:51 +00:00
6fb0e6a417
Use sC_flat (non-swizzled epi_s layout) for TMA store from SMEM accumulator
2026-05-27 05:26:50 +00:00
4a2a06f9e1
Fix gO slice: use separate Int32(0) instead of tuple
2026-05-27 05:25:33 +00:00
bf36979a8d
Use CUTLASS FMHA reference pattern for sC->GMEM TMA store (flat_divide + tma_partition)
2026-05-27 05:24:39 +00:00
97bc6d8d2f
Add c_direct GMEM tensor for direct writes in SMEM accumulator path
2026-05-27 05:15:47 +00:00
3d349b497b
SME accumulator: direct GMEM write from sO_acc (bypass TMA for multi-kt)
2026-05-27 05:14:31 +00:00
7d1e0a605d
Different coordinate dims for bSG_sC (2D) and bSG_gC (3D)
2026-05-27 05:13:38 +00:00
75b272c5f2
2D coordinate for bSG_sC TMA copy
2026-05-27 05:12:58 +00:00
72dff90165
3D coordinate for bSG_sC/gC TMA copy
2026-05-27 05:12:11 +00:00
b8b6e8cc0b
Slice bSG_gC MMA tile coords for TMA copy
2026-05-27 05:11:26 +00:00
754740d5e5
Try bSG_sC[(None, 0)] for TMA copy coordinate
2026-05-27 05:10:40 +00:00
23a2b49daf
Add SMEM accumulator for n_kv_tiles>1: O load from TMEM, accumulate in sO_acc, TMA store from sC
2026-05-27 05:09:54 +00:00
a858ed1c14
Fix test: normalize=False for un-normalized O comparison
2026-05-27 05:06:52 +00:00
2e262d2b99
Reset fmha_smem_acc.py to working fmha.py base
2026-05-27 05:05:41 +00:00
b43ffe9dac
Guard sO_acc allocation/zero-init with n_kv_tiles>1
2026-05-27 05:05:01 +00:00
101840c78c
Guard SMEM accumulation with n_kv_tiles>1 to avoid TMEM destructive read
2026-05-27 05:02:51 +00:00
02a34512cb
Use epilogue_tma_store for n_kv_tiles=1; TODO for multi-tile
2026-05-27 05:01:39 +00:00
4652cab8b4
Fix: 3D coords for TMA copy (bSG_sC has 3 modes)
2026-05-27 05:00:39 +00:00
b0ebf41ee3
Slice bSG_gC with mma_tile_coord (like epilogue_tma_store)
2026-05-27 05:00:04 +00:00
eb0bf0cce0
Fix TMA store: use bSG_sC[(None,0)] indexing pattern from epilogue_tma_store
2026-05-27 04:59:29 +00:00
7ea77a121f
Use cpasync.tma_partition for SMEM->GMEM TMA store (like epilogue_tma_store)
2026-05-27 04:58:47 +00:00
e614d0894c
Clean up SMEM acc epilogue: flat indexing sO_acc->sC, TMA store from sC_s0
2026-05-27 04:57:40 +00:00
1724eeb8ec
Fix TMA store: use epi_s view of sC for proper layout compatibility
2026-05-27 04:55:18 +00:00
3a7d87adba
Fix test_smem_acc: use keyword args for lse/row_sums
2026-05-27 04:54:23 +00:00
6a621bdf64
D1.5: SMEM accumulator FMHA kernel — one-way TMEM→REGS→SMEM, no round-trip
...
TMEM round-trip (Ld32x32bOp/St32x32bOp) is FUNDAMENTALLY BROKEN.
Even NO-OP (multiply by 1.0) corrupts data.
New approach:
- PV always ACCUMULATE=False (fresh TMEM each kt)
- After pv_done_bar: one-way Ld32x32bOp load O_kt from TMEM→REGS
- Coordinate-indexed SMEM accumulation: sO_acc = acc_scale * sO_acc + O_kt
- sO_acc: FP32 [128, pv_n_tile] row-major (32KB at hd=64, 64KB at hd=128)
- Final: normalize, cast BF16, write to sC, TMA store to GMEM
2026-05-27 04:53:40 +00:00
81acf1593c
Revert "D1.5: WIP SMEM accumulator — framework in place, accumulation logic TODO"
...
This reverts commit 72d88af400 .
2026-05-27 02:17:26 +00:00
72d88af400
D1.5: WIP SMEM accumulator — framework in place, accumulation logic TODO
...
Added epilogue_tmem_copy_and_partition + epilogue_smem_copy_and_partition
setup for multi-KV-tile O rescale. The one-way TMEM→REGS→SMEM pipeline
is wired up, but the SMEM-level accumulation (load-previous, scale, add,
store-back) needs implementation. Currently falls through to Python KV merge.
2026-05-27 02:15:23 +00:00
a6da93ddfb
Revert "D1.5: Try O rescale with tCtO_base layout (epilogue-proven TMEM addressing)"
...
This reverts commit 79e2eb3b42 .
2026-05-27 02:12:20 +00:00