Commit Graph

619 Commits

Author SHA1 Message Date
01621e1520 Diag: try runtime Int32(0+0) for kv_coord with cutlass.range 2026-05-22 17:57:58 +00:00
beecc4df47 Diag: use Python range() unrolling like stage C test 2026-05-22 17:56:59 +00:00
200430bd3f Fix diagnostic test: same Int32(kt) + n_kv_tiles fixes 2026-05-22 17:56:15 +00:00
c23ebd5b57 Try cutlass.range with Int32(kt) — now n_kv_tiles is Python int 2026-05-22 17:51:25 +00:00
4a41df51c4 FIX: n_kv_tiles as Python int (s_k//128) for range() unrolling
cute.size() returns a CuTeDSL symbol, not a Python int.
range() on a symbol can't iterate — the loop never unrolls.
Now n_kv_tiles is computed in __init__ as s_k // 128 (Python int).
2026-05-22 17:50:07 +00:00
70409636f7 Option 2: Python range() with Int32(kt) for TMA GMEM coord
cutlass.range traces once - kv_coord/kt are trace-time values,
not runtime loop-carried state. Python range() fully unrolls at
trace time, emitting distinct Int32(k) constants per iteration.
Int32(1) hardcoded already proved TMA CAN load from tile 1.
2026-05-22 17:47:43 +00:00
b55a38c4c3 Add example5: use cutlass.range induction variable as TMA GMEM coord 2026-05-22 17:47:10 +00:00
aacad257ea README: add fire_b200_test docs, update multi-tile blocker with real findings 2026-05-22 17:41:23 +00:00
93c28b9c29 Clean up debug prints, set kv_coord as Int32(0)
Key findings to relay to CUTLASS LLM:
- kv_coord=Int32(1) hardcode CHANGES the output (TMA CAN load from different tiles)
- kv_coord=Int32(0) + kv_coord += 1 does NOT increment at runtime
  (all multi-tile outputs identical to kv_coord=0)
- kv_coord=0 (plain Python int) also doesn't work
- Pipeline handle .count doesn't work either
- The TMA GMEM tile coordinate must be dynamic at kernel runtime,
  but CuTeDSL appears to constant-fold or not propagate the increment
2026-05-22 17:39:27 +00:00
1bba851911 DEBUG: try plain Python int kv_coord (like CUTLASS ref) 2026-05-22 17:34:30 +00:00
15b2a28d29 DEBUG: hardcode kv_coord=1 to test if TMA uses it 2026-05-22 17:32:53 +00:00
ff9ef6dcde DEBUG: try K slice (None,0,None,0) keeping mode 2 free 2026-05-22 17:30:06 +00:00
cec6f59d66 DEBUG: print tBgK/tVgV shapes before/after slice 2026-05-22 17:28:45 +00:00
8740ab5b27 Stage C: manual kv_coord + correct K GMEM slice + O rescale fence
Key fixes:
1. GMEM tile coord: manual Int32 kv_coord (not kvh.count)
2. K GMEM slice: (None,None,0,0) keeps mode 1 free (GMEM iter)
3. V GMEM slice: (None,0,None,0) keeps mode 2 free (GMEM iter)
4. Add fence_view_async_tmem_load before O rescale for visibility
2026-05-22 17:26:56 +00:00
f39c3a8e38 Add example4: manual kv_coord Int32 for GMEM tile indexing
Pipeline handle .count is NOT a GMEM coordinate — it's opaque pipeline
state. CUTLASS FMHA reference uses a manual Int32 counter (kv_coord)
for GMEM and handle.index for SMEM. .count is never used as a
coordinate anywhere in the reference.
2026-05-22 17:24:38 +00:00
711480a4fb README: add test harness instructions 2026-05-22 17:09:53 +00:00
83f8a13add run_test.sh: SIGKILL all children of screen session on cleanup
Deadlocked GPU processes ignore SIGHUP from screen -X quit.
Now kills the entire process group with SIGKILL, plus a catch-all
pkill for any python test_ processes.
2026-05-22 17:08:12 +00:00
f4cfeae262 Add check_log.sh convenience script 2026-05-22 17:07:23 +00:00
61122e10c6 Fix quoting in run_test.sh 2026-05-22 17:06:00 +00:00
af475affab Add run_test.sh harness (screen + log) 2026-05-22 17:05:43 +00:00
d0626e0434 FIX: only slice GMEM tensors (SMEM already 2D from tma_partition) 2026-05-22 16:57:31 +00:00
15a2fdadbc FIX: consistent GMEM/SMEM slicing for K and V TMA partitions
Both GMEM and SMEM sides must be sliced to the same rank for cute.copy.
K (QK MMA B-partition): slice [(None,None,0,0)] keeps modes 0,1
  - mode 1 = GMEM iteration, indexed by kvh.count
V (PV MMA B-partition): slice [(None,0,None,0)] keeps modes 0,2
  - mode 2 = GMEM iteration, indexed by kvh.count
Q: only 1 tile, (None,0,None,0) hardcode is fine.
2026-05-22 16:56:38 +00:00
5c1423b2c5 FIX: keep GMEM iteration dimension FREE in TMA K/V partition slices
Root cause of multi-tile failure: (None,0,None,0) slice hardcodes the
GMEM tile dimension to 0, so TMA always loads from tile 0 regardless
of kvh.count. K from QK MMA has GMEM iter at mode 1, V from PV MMA
has it at mode 2 (different layouts: K,D,L vs D,K,L).

Fix follows CUTLASS FMHA reference:
- K: tBgK[(None,None,None,0)] + tBgK[(None, kvh.count, None)]
- V: tVgV[(None,0,None,0)] + tVgV[(None, kvh.count)]
2026-05-22 16:51:57 +00:00
018c09644b Add diagnostic test for multi-tile TMA pipeline (identity softmax) 2026-05-22 16:47:08 +00:00
ec895e831e FIX: acc_scale was double-multiplying by scale_log2
row_max is already in log2(scaled) space (S * scale_log2), so
old_row_max - row_max_safe is the correct exponent for exp2.
The old code computed exp2(scale_log2 * (old_row_max - row_max_safe))
which is exp2(scale_log2^2 * (old_max_S - new_max_S)) — wrong.
2026-05-22 16:42:45 +00:00
2147cce95d Stage C: integrate example3 multi-tile fixes into unit test
- Combined K+V barrier (one acquire per kt, kvh.count == kt)
- O rescale for kt > 0 (online softmax O correction)
- final_o_bar sync (MMA signals before producer_tail)
- s_k as constructor param (compile-time for V layout)
- kv_tx_bytes covers both K and V transfers
- Test covers n=128, 256, 512, 1024
2026-05-22 16:39:45 +00:00
dad35818f3 README + MEMORY: update Stage C status to single-tile only, document multi-tile blocker
- Stage C works for n=128 (0.993) but multi-tile (n>128) is broken
- Root cause: tBgK slice hardcodes GMEM iteration to tile 0
- CuTeDSL TMA copy doesn't accept Python int as tile index
- Mike's combined K+V barrier fix compiles but deadlocks at runtime
- Fallback: kh.count // 2 (untested)
2026-05-22 16:32:31 +00:00
e4c82873bb FMHA Stage-C multi-tile: combined K+V barrier, final_o_bar, acc_pipe producer
Key changes from Mike:
1. Combined K+V TMA barrier: one acquire per kt, both cute.copys share
   kvh.barrier. kvh.count naturally == kt (no interleaving problem).
   tx_count = K_bytes + V_bytes. Also fixes the sK[0]/sV[1] slot quirk.
2. final_o_bar NamedBarrier: MMA .arrive() after acc_pipe.producer_tail;
   softmax .arrive_and_wait() before reading O for normalize. Prevents
   softmax racing MMA's PV[N-1] on the final O read.
3. acc_pipe producer in MMA: producer_acquire before loop, commit+advance
   after loop, producer_tail after. Consumer in epilogue as before.
4. O rescale re-enabled for kt>0 with acc_scale before softmax_done_bar.
2026-05-22 16:23:36 +00:00
39fa5b96b0 restore tBgK to kh.count indexing (single-tile working), add TODO for multi-tile
CuTeDSL TMA copy API doesn't support dynamic GMEM tile indexing.
kh.count works for single tile. For multi-tile, need to either:
1. Map pipeline count to tile index (kh.count // 2 for interleaved K/V)
2. Separate K and V into non-interleaved TMA loops
3. Use gK/gV layouts that iterate naturally with pipeline count

This is the architectural blocker for multi-tile FMHA.
2026-05-22 15:54:03 +00:00
f1854bab26 FIX: use unsliced tBgK with (None, kt, None, 0) for proper GMEM tile indexing
The pre-slice (None,0,None,0) hardcoded GMEM iteration to tile 0.
Instead, keep the original tBgK and index with (None, kt, None, 0)
inside the TMA loop, where kt selects the correct GMEM tile.
This preserves 2D rank matching with the SMEM tensor.
2026-05-22 15:52:56 +00:00
3d2cb0e52b CRITICAL FIX: keep GMEM iteration dim free in tBgK/tVgV slice
The slice (None,0,None,0) was hardcoding the GMEM iteration dim to 0,
meaning TMA always loaded K/V from tile 0 regardless of kt.
Changed to (None,None,None,0) to keep gmem_iter free,
then index with (None, kt, None) in the TMA copy loop.

This is the root cause of multi-tile failure: TMA was always reading
the first 128 tokens for ALL KV tiles.
2026-05-22 15:52:06 +00:00
a04b219f0f add explicit acc_pipe.consumer_wait before final normalize
Race condition: softmax reads O to normalize while MMA may still be
writing PV[N-1]. Single-tile wins by luck; multi-tile drifts.
Move acc_cons_st construction before the wait so epilogue reuses it.
2026-05-22 15:49:48 +00:00
ff27a261b1 FMHA Stage-C multi-tile: Fix 1 (s_k=n), Fix 2 (TMA kt indexing), Fix 3 (O rescale)
Fix 1: s_k must equal actual n. With s_k < n, v_fmha layout only spans
first s_k V tokens and TMA reads OOB on later tiles.

Fix 2: TMA producer indexes K and V by kt (loop variable), NOT by the
pipeline's interleaved count. The kv pipeline interleaves K and V, so
pipeline count goes 0,1,2,3 but GMEM tiles should be K[0],V[0],K[1],V[1].

Fix 3: Online O rescale before softmax_done_bar. When row_max grows,
O must be multiplied by exp2(old_max - new_max) before MMA starts next PV.
2026-05-22 15:41:14 +00:00
5a08b79364 Revert "debug: test 12w identity softmax with n=256 to verify multi-tile pipeline"
This reverts commit 6cf8702e3c.
2026-05-22 10:25:48 +00:00
6cf8702e3c debug: test 12w identity softmax with n=256 to verify multi-tile pipeline 2026-05-22 10:24:53 +00:00
a3c9af8fa3 debug: disable O rescaling to test multi-tile pipeline baseline 2026-05-22 10:23:37 +00:00
c175ec4f09 fix: revert to scaled row_max, use exp2(old_max - new_max) for O rescaling
row_max is in scaled domain (s_val * scale_log2). The O rescaling
should be exp2(old_max - new_max) without extra scale_log2 because
the max values already include the scaling factor.
2026-05-22 10:22:44 +00:00
35c8043064 fix: compute row_max from RAW S values, not scaled
row_max should be the max of the raw QK scores, not pre-scaled.
The scale_log2 is applied during exp2 and rescaling, not stored in row_max.
This fixes the double-scaling bug that broke multi-tile O rescaling.
2026-05-22 10:21:50 +00:00
f9f5647eaa fix: missing newline after self.s_k = s_k 2026-05-22 10:20:35 +00:00
e0c320929a fix: add s_k param to FmhaV3StageC, use self.s_k for V FMHA reconstruction 2026-05-22 10:19:49 +00:00
fb4ffd8cf7 Stage C: add online O rescaling for multi-tile KV + test n=256
- Move O TMEM load/store setup before softmax loop
- After P store: rescale O in TMEM by exp2((old_max - new_max) * scale)
- Only rescale for kt > 0 (first tile has no prior O to rescale)
- Use same TMEM load/modify/store pattern as final normalization
- Test both n=128 (1 tile) and n=256 (2 tiles)
2026-05-22 10:19:08 +00:00
94b0d97107 fix: add epilogue warp to tmem_bar, restore wait_for_alloc in epilogue
The epilogue needs tmem_ptr for epilogue_tma_store.
It must be part of the tmem alloc barrier to synchronize.
2026-05-22 10:17:02 +00:00
65e52f5934 fix: add softmax_done_bar to synchronize MMA PV with softmax P production
MMA must wait for softmax to produce P in TMEM before starting PV.
Without this, MMA reads stale P data from TMEM, causing deadlock.
softmax_done_bar: softmax warps arrive after P store, MMA waits before PV.
2026-05-22 10:15:26 +00:00
ea687980af fix: epilogue warp self-signals acc_pipe producer before consuming 2026-05-22 10:11:55 +00:00
19b742f365 fix: remove duplicate tmem free from epilogue (MMA warp handles dealloc) 2026-05-22 10:05:52 +00:00
0a3815049f fix: add acc_pipe pipeline for epilogue, matching 12w pattern
- Add acc_bar to SS struct
- Create acc_pipe (full pipeline) before if blocks
- Pass acc_pipe to epilogue_tma_store (needs full pipeline, not participant)
2026-05-22 10:03:08 +00:00
59f4d8a469 fix: epilogue_warp_id must be tuple for epilogue_tma_store, check with [0] 2026-05-22 09:59:20 +00:00
6ba12b7890 fix: epilogue warp reuse mma_corr_cons pipeline instead of creating new one from st 2026-05-22 09:56:18 +00:00
540399eca3 fix: define cS and tScS in correction warps (not visible across if blocks) 2026-05-22 09:52:59 +00:00
ee859099bd fix: correct @cute.kernel indentation 2026-05-22 09:49:36 +00:00