Commit Graph

637 Commits

Author SHA1 Message Date
a38c8d4140 Real softmax test built on working identity diag 2026-05-22 18:55:31 +00:00
eea70fbda8 CRITICAL FIX: remove extra scale_log2 in softmax (minus_row_max and acc_scale) 2026-05-22 18:52:58 +00:00
95cc5e6ee5 FIX: K slice (None,None,0,0) like working diag 2026-05-22 18:51:22 +00:00
364e64fdeb Diag: identity softmax on example6 pipeline to isolate softmax bug 2026-05-22 18:49:33 +00:00
454445dd2a Quick test: working v3 with n=256 multi-tile 2026-05-22 18:47:54 +00:00
c7514de0ae DEBUG: add version marker to confirm code changes are running 2026-05-22 18:47:05 +00:00
20aa6b4d90 CRITICAL FIX: TMA pre-slice (None,0,None,0) → (None,None,0,0) to keep GMEM tile dim free 2026-05-22 18:45:55 +00:00
5ac4f483de Diag: TMA shapes with hardcoded major modes 2026-05-22 18:43:55 +00:00
f4abc4c115 Diag: simplified TMA shape analysis 2026-05-22 18:43:30 +00:00
1e26764664 Diag: print TMA partition shapes for multi-tile debugging 2026-05-22 18:43:05 +00:00
3b46abdfc0 FIX: Use Python range() in TMA warp for concrete per-iteration GMEM coords 2026-05-22 18:26:31 +00:00
bb1d86f954 FIX: Force SSA GMEM coord via n_kv_tiles - n_kv_tiles instead of cutlass.range kt 2026-05-22 18:25:13 +00:00
67b48bdedd FIX: TMEM offset bug in O rescale/normalize — use tOtO0.iterator not tOtO.iterator 2026-05-22 18:23:46 +00:00
877076343c Diag: test n=384 (3 tiles) to find crash boundary 2026-05-22 18:07:07 +00:00
cd72598560 Diag: test all sizes 128-1024 2026-05-22 18:06:28 +00:00
06852ba826 DEBUG: disable O rescale to isolate NaN cause 2026-05-22 18:05:46 +00:00
b540d48e44 Add NaN/inf checking to stage C test 2026-05-22 18:01:11 +00:00
5385ea7a32 CRITICAL FIX: K GMEM slice (None,None,0,0) not (None,0,None,0)
K from QK MMA B-partition has GMEM iter at mode 1, NOT mode 2.
(None,0,None,0) hardcodes mode 1 to 0 → TMA always loads tile 0.
(None,None,0,0) keeps mode 1 free → correct multi-tile loading.

Proof: diag n=256 went from cos 0.711 → 0.999999 with this one change.
2026-05-22 17:59:57 +00:00
dd27dc7f7a Diag: try K slice (None,None,0,0) keeping mode 1 (CUTLASS ref style) 2026-05-22 17:59:01 +00:00
18ca9d4661 Diag: try runtime Int32(0+0) for kv_coord with cutlass.range 2026-05-22 17:57:58 +00:00
f0832b585b Diag: use Python range() unrolling like stage C test 2026-05-22 17:56:59 +00:00
19933b7d65 Fix diagnostic test: same Int32(kt) + n_kv_tiles fixes 2026-05-22 17:56:15 +00:00
ee2f022aa0 Try cutlass.range with Int32(kt) — now n_kv_tiles is Python int 2026-05-22 17:51:25 +00:00
fd98b3387e FIX: n_kv_tiles as Python int (s_k//128) for range() unrolling
cute.size() returns a CuTeDSL symbol, not a Python int.
range() on a symbol can't iterate — the loop never unrolls.
Now n_kv_tiles is computed in __init__ as s_k // 128 (Python int).
2026-05-22 17:50:07 +00:00
82c46d438e Option 2: Python range() with Int32(kt) for TMA GMEM coord
cutlass.range traces once - kv_coord/kt are trace-time values,
not runtime loop-carried state. Python range() fully unrolls at
trace time, emitting distinct Int32(k) constants per iteration.
Int32(1) hardcoded already proved TMA CAN load from tile 1.
2026-05-22 17:47:43 +00:00
06a0f5fc53 Add example5: use cutlass.range induction variable as TMA GMEM coord 2026-05-22 17:47:10 +00:00
fc45d1cdda README: add fire_b200_test docs, update multi-tile blocker with real findings 2026-05-22 17:41:23 +00:00
97a0e3e045 Clean up debug prints, set kv_coord as Int32(0)
Key findings to relay to CUTLASS LLM:
- kv_coord=Int32(1) hardcode CHANGES the output (TMA CAN load from different tiles)
- kv_coord=Int32(0) + kv_coord += 1 does NOT increment at runtime
  (all multi-tile outputs identical to kv_coord=0)
- kv_coord=0 (plain Python int) also doesn't work
- Pipeline handle .count doesn't work either
- The TMA GMEM tile coordinate must be dynamic at kernel runtime,
  but CuTeDSL appears to constant-fold or not propagate the increment
2026-05-22 17:39:27 +00:00
bea7ab617f DEBUG: try plain Python int kv_coord (like CUTLASS ref) 2026-05-22 17:34:30 +00:00
f3c3186016 DEBUG: hardcode kv_coord=1 to test if TMA uses it 2026-05-22 17:32:53 +00:00
e45b533a9e DEBUG: try K slice (None,0,None,0) keeping mode 2 free 2026-05-22 17:30:06 +00:00
46ded465da DEBUG: print tBgK/tVgV shapes before/after slice 2026-05-22 17:28:45 +00:00
62bbc5ce42 Stage C: manual kv_coord + correct K GMEM slice + O rescale fence
Key fixes:
1. GMEM tile coord: manual Int32 kv_coord (not kvh.count)
2. K GMEM slice: (None,None,0,0) keeps mode 1 free (GMEM iter)
3. V GMEM slice: (None,0,None,0) keeps mode 2 free (GMEM iter)
4. Add fence_view_async_tmem_load before O rescale for visibility
2026-05-22 17:26:56 +00:00
da6f211c44 Add example4: manual kv_coord Int32 for GMEM tile indexing
Pipeline handle .count is NOT a GMEM coordinate — it's opaque pipeline
state. CUTLASS FMHA reference uses a manual Int32 counter (kv_coord)
for GMEM and handle.index for SMEM. .count is never used as a
coordinate anywhere in the reference.
2026-05-22 17:24:38 +00:00
4faadff452 README: add test harness instructions 2026-05-22 17:09:53 +00:00
aa10f234da run_test.sh: SIGKILL all children of screen session on cleanup
Deadlocked GPU processes ignore SIGHUP from screen -X quit.
Now kills the entire process group with SIGKILL, plus a catch-all
pkill for any python test_ processes.
2026-05-22 17:08:12 +00:00
9af2ffe6f2 Add check_log.sh convenience script 2026-05-22 17:07:23 +00:00
34ee991835 Fix quoting in run_test.sh 2026-05-22 17:06:00 +00:00
7329dc39a6 Add run_test.sh harness (screen + log) 2026-05-22 17:05:43 +00:00
6e1a150d69 FIX: only slice GMEM tensors (SMEM already 2D from tma_partition) 2026-05-22 16:57:31 +00:00
594878101d FIX: consistent GMEM/SMEM slicing for K and V TMA partitions
Both GMEM and SMEM sides must be sliced to the same rank for cute.copy.
K (QK MMA B-partition): slice [(None,None,0,0)] keeps modes 0,1
  - mode 1 = GMEM iteration, indexed by kvh.count
V (PV MMA B-partition): slice [(None,0,None,0)] keeps modes 0,2
  - mode 2 = GMEM iteration, indexed by kvh.count
Q: only 1 tile, (None,0,None,0) hardcode is fine.
2026-05-22 16:56:38 +00:00
0065c0180a FIX: keep GMEM iteration dimension FREE in TMA K/V partition slices
Root cause of multi-tile failure: (None,0,None,0) slice hardcodes the
GMEM tile dimension to 0, so TMA always loads from tile 0 regardless
of kvh.count. K from QK MMA has GMEM iter at mode 1, V from PV MMA
has it at mode 2 (different layouts: K,D,L vs D,K,L).

Fix follows CUTLASS FMHA reference:
- K: tBgK[(None,None,None,0)] + tBgK[(None, kvh.count, None)]
- V: tVgV[(None,0,None,0)] + tVgV[(None, kvh.count)]
2026-05-22 16:51:57 +00:00
b7c3b6613c Add diagnostic test for multi-tile TMA pipeline (identity softmax) 2026-05-22 16:47:08 +00:00
3712fda6ce FIX: acc_scale was double-multiplying by scale_log2
row_max is already in log2(scaled) space (S * scale_log2), so
old_row_max - row_max_safe is the correct exponent for exp2.
The old code computed exp2(scale_log2 * (old_row_max - row_max_safe))
which is exp2(scale_log2^2 * (old_max_S - new_max_S)) — wrong.
2026-05-22 16:42:45 +00:00
8ac35a32d7 Stage C: integrate example3 multi-tile fixes into unit test
- Combined K+V barrier (one acquire per kt, kvh.count == kt)
- O rescale for kt > 0 (online softmax O correction)
- final_o_bar sync (MMA signals before producer_tail)
- s_k as constructor param (compile-time for V layout)
- kv_tx_bytes covers both K and V transfers
- Test covers n=128, 256, 512, 1024
2026-05-22 16:39:45 +00:00
4fb606d41a README + MEMORY: update Stage C status to single-tile only, document multi-tile blocker
- Stage C works for n=128 (0.993) but multi-tile (n>128) is broken
- Root cause: tBgK slice hardcodes GMEM iteration to tile 0
- CuTeDSL TMA copy doesn't accept Python int as tile index
- Mike's combined K+V barrier fix compiles but deadlocks at runtime
- Fallback: kh.count // 2 (untested)
2026-05-22 16:32:31 +00:00
3ddee26eff FMHA Stage-C multi-tile: combined K+V barrier, final_o_bar, acc_pipe producer
Key changes from Mike:
1. Combined K+V TMA barrier: one acquire per kt, both cute.copys share
   kvh.barrier. kvh.count naturally == kt (no interleaving problem).
   tx_count = K_bytes + V_bytes. Also fixes the sK[0]/sV[1] slot quirk.
2. final_o_bar NamedBarrier: MMA .arrive() after acc_pipe.producer_tail;
   softmax .arrive_and_wait() before reading O for normalize. Prevents
   softmax racing MMA's PV[N-1] on the final O read.
3. acc_pipe producer in MMA: producer_acquire before loop, commit+advance
   after loop, producer_tail after. Consumer in epilogue as before.
4. O rescale re-enabled for kt>0 with acc_scale before softmax_done_bar.
2026-05-22 16:23:36 +00:00
d83434384e restore tBgK to kh.count indexing (single-tile working), add TODO for multi-tile
CuTeDSL TMA copy API doesn't support dynamic GMEM tile indexing.
kh.count works for single tile. For multi-tile, need to either:
1. Map pipeline count to tile index (kh.count // 2 for interleaved K/V)
2. Separate K and V into non-interleaved TMA loops
3. Use gK/gV layouts that iterate naturally with pipeline count

This is the architectural blocker for multi-tile FMHA.
2026-05-22 15:54:03 +00:00
493d8e817d FIX: use unsliced tBgK with (None, kt, None, 0) for proper GMEM tile indexing
The pre-slice (None,0,None,0) hardcoded GMEM iteration to tile 0.
Instead, keep the original tBgK and index with (None, kt, None, 0)
inside the TMA loop, where kt selects the correct GMEM tile.
This preserves 2D rank matching with the SMEM tensor.
2026-05-22 15:52:56 +00:00
d3031eefd8 CRITICAL FIX: keep GMEM iteration dim free in tBgK/tVgV slice
The slice (None,0,None,0) was hardcoding the GMEM iteration dim to 0,
meaning TMA always loaded K/V from tile 0 regardless of kt.
Changed to (None,None,None,0) to keep gmem_iter free,
then index with (None, kt, None) in the TMA copy loop.

This is the root cause of multi-tile failure: TMA was always reading
the first 128 tokens for ALL KV tiles.
2026-05-22 15:52:06 +00:00