From 61b0501a8b337291450092ca5b2d3b7bd4da7ff1 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Fri, 22 May 2026 21:50:35 +0000 Subject: [PATCH] Fix test_fmha_v3_stage_c.py: 8-mode TMA indexing + O rescale (from example9) --- tests/unit/test_fmha_v3_stage_c.py | 257 +++++++++++------------------ 1 file changed, 96 insertions(+), 161 deletions(-) diff --git a/tests/unit/test_fmha_v3_stage_c.py b/tests/unit/test_fmha_v3_stage_c.py index f917ff9c..dffbabde 100644 --- a/tests/unit/test_fmha_v3_stage_c.py +++ b/tests/unit/test_fmha_v3_stage_c.py @@ -10,30 +10,31 @@ Two structural rules we had to learn the hard way: (B) Hand-constructed TMEM load/store atoms (Ld32x32bOp + St32x32bOp built independently) DO NOT preserve register tile shape across a round-trip. - A no-op TMEM-load-then-TMEM-store visibly corrupts data. Use the paired - atoms from `utils.sm100.get_tmem_load_op` + `get_smem_store_op` — they - are configured together for the same (mma_tiler, layout, dtype) combo - and the register tile shape lines up. This is what the CUTLASS Blackwell - FMHA reference does in `correction_epilog`. + Use paired atoms (or, as we discovered: independently constructed atoms + DO work if they're built from the SAME `Repetition(N)` count — the + Ld32x32bOp(Rep(16)) + St32x32bOp(Rep(16)) pair preserves the register + tile shape exactly because the atom width matches). This is what the + CUTLASS Blackwell FMHA reference does in `correction_rescale`. + +(C) Multi-tile GMEM indexing: after tma_partition, tBgK/tVgV have 8 modes. + Mode 4 is the GMEM tile iteration axis. Pre-slicing with (None,None,0,0) + silently collapses modes 4-7 to coord 0, so TMA always reads tile 0 + regardless of the coordinate passed. FIX: use 8-None no-op slice to + preserve all modes, then (None, kt) indexing in cute.copy. Kernel structure: 1. Combined K+V pipeline (tx_count = K_bytes + V_bytes; one acquire per kt; K and V share the same barrier slot). SMEM slot via kvh.index, GMEM via - the cutlass.range loop variable. + loop variable kt indexing mode 4 of the 8-mode TMA partition tensor. -2. Reference-style epilogue (TMEM → reg → scale by 1/row_sum → FP32→BF16 in - reg → SMEM via paired atoms → TMA SMEM→GMEM). One pass, no TMEM - round-trip, no `epilogue_tma_store` helper. Inline TMA store + named - barrier sync to substitute for what the helper would have done. +2. Reference-style scaled epilogue: TMEM correction_rescale (O *= 1/row_sum + via paired Ld32x32b + St32x32b atoms), then standard epilogue_tma_store + to send O from TMEM through SMEM to GMEM. -3. Online softmax row_max / row_sum tracking is correct, but the per-tile - in-place TMEM O rescale (multiplying existing O by exp2(old_max - new_max) - before PV[kt]) is currently DISABLED. Fixing that requires applying the - same paired-atom pattern to a separate scratch SMEM buffer and bouncing - PV's accumulator through it, which is substantial work. For now, the - kernel is correct when row_max growth across tiles is mild. Long n with - pronounced max growth will drift; the fix path is well-defined. +3. Per-tile O rescale (multiplying existing O by exp2(old_max - new_max) + before PV[kt]) lives in the softmax warp BEFORE softmax_done_bar.arrive(). + Reuses the same paired-atom pattern as the final normalize. 4. final_o_bar (32 MMA + 128 softmax threads). MMA arrives between acc_pipe.producer_commit and producer_tail; softmax arrives_and_waits @@ -56,6 +57,7 @@ class FmhaV3StageCMulti: def __init__(self, s_k=128, scale_softmax=None): # s_k MUST equal actual sequence length n. self.s_k = s_k + self.n_kv_tiles = s_k // 128 self.acc_dtype = Float32; self.qk_acc_dtype = Float32 self.q_dtype = BFloat16; self.o_dtype = BFloat16; self.c_dtype = BFloat16 self.use_2cta_instrs = False; self.epilog_sync_bar_id = 1 @@ -148,12 +150,9 @@ class FmhaV3StageCMulti: smem = utils.SmemAllocator(); st = smem.allocate(SS) qp,qc = pipeline.PipelineTmaUmma.create(barrier_storage=st.q_bar.data_ptr(),num_stages=self.q_stage,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,1),tx_count=self.q_tx_bytes,cta_layout_vmnk=cl_vmnk,defer_sync=True).make_participants() - # Combined K+V pipeline: each stage carries BOTH K and V loaded together. kvp,kvc = pipeline.PipelineTmaUmma.create(barrier_storage=st.kv_bar.data_ptr(),num_stages=self.kv_stage,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,1),tx_count=self.kv_tx_bytes,cta_layout_vmnk=cl_vmnk,defer_sync=True).make_participants() s_prod,s_cons = pipeline.PipelineUmmaAsync.create(barrier_storage=st.s_bar.data_ptr(),num_stages=1,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,32*len(self.epilogue_warp_id))).make_participants() softmax_done_bar = pipeline.NamedBarrier(barrier_id=3, num_threads=32 + 32*len(self.epilogue_warp_id)) - # Final-O sync: MMA arrives between producer_commit and producer_tail; - # softmax arrives_and_waits before reading O for the final normalize. final_o_bar = pipeline.NamedBarrier(barrier_id=4, num_threads=32 + 32*len(self.epilogue_warp_id)) acc_pipe = pipeline.PipelineUmmaAsync.create(barrier_storage=st.acc_bar.data_ptr(),num_stages=self.num_acc_stage,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,len(self.epilogue_warp_id)),cta_layout_vmnk=cl_vmnk,defer_sync=True) tmem_bar = pipeline.NamedBarrier(barrier_id=2,num_threads=32*len((self.mma_warp_id,*self.epilogue_warp_id))) @@ -179,11 +178,13 @@ class FmhaV3StageCMulti: b_lay = cute.make_layout(cute.slice_(cl_vmnk,(0,None,0,0)).shape) tBsK,tBgK = cpasync.tma_partition(tma_k,0,b_lay,cute.group_modes(sK,0,3),cute.group_modes(tCgK,0,3)) tVsV,tVgV = cpasync.tma_partition(tma_v,0,b_lay,cute.group_modes(sV,0,3),cute.group_modes(tCgV,0,3)) - # TMA source tensor slices: keep the GMEM tile dimension (mode 4) free - # tBgK shape: (1, 1, 1, 1, 2, 1, 1, 1) — 8 modes, mode 4 = kv_tiles - # tVgV shape: (1, 1, 1, 1, 2, 1, 1, 1) — 8 modes, mode 4 = kv_tiles + # CRITICAL: tBgK/tVgV have 8 modes after tma_partition. + # Mode 4 is the GMEM tile iteration axis. Pre-slicing with + # (None,None,0,0) collapses modes 4-7 to 0 — TMA always reads tile 0. + # Fix: 8-None no-op slice preserves all modes; (None, kt) in copy + # addresses mode 4 correctly. tAgQ = tAgQ[(None,0,None,0)] - tBgK = tBgK[(None,None,None,None,None,None,None,None)] # No-op, use full indexing in copy + tBgK = tBgK[(None,None,None,None,None,None,None,None)] tVgV = tVgV[(None,None,None,None,None,None,None,None)] tCrQ = qk_mma.make_fragment_A(sQ); tCrK = qk_mma.make_fragment_B(sK) @@ -208,20 +209,18 @@ class FmhaV3StageCMulti: pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk) # ===== TMA LOAD warp ===== - # GMEM tile coordinate: use the cutlass.range induction variable kt - # directly. CuTeDSL's `cutlass.range` doesn't auto-detect a Python `+=` - # rebinding as a loop-carried iter_args update — the JIT traces the - # body once and captures whatever value `kv_coord` had at trace time, - # so an outer `kv_coord = Int32(0)` plus a `kv_coord += 1` inside the - # loop bakes 0 into every iteration's TMA descriptor at runtime. - # The induction variable IS the loop-carried state, properly tracked. + # kt from cutlass.range indexes mode 4 of the 8-mode TMA tensor, + # which is the GMEM tile iteration axis. Pipeline state (kvh.index) + # selects the SMEM ring buffer slot. if warp_idx == self.tma_warp_id: qp.reset(); qh = qp.acquire_and_advance() cute.copy(tma_q, tAgQ[(None, Int32(0))], tAsQ[(None, qh.index)], tma_bar_ptr=qh.barrier) qp.tail() - kvp.reset(); pk = kvp.try_acquire() + kvp.reset() + # With 8-mode TMA tensor preserved, kt from cutlass.range + # correctly addresses mode 4 (GMEM tile dim) in cute.copy. for kt in cutlass.range(0, n_kv_tiles, 1, unroll=1): - kvh = kvp.acquire_and_advance(pk) + kvh = kvp.acquire_and_advance() cute.copy(tma_k, tBgK[(None, kt)], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier) cute.copy(tma_v, tVgV[(None, kt)], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier) kvp.tail() @@ -229,14 +228,16 @@ class FmhaV3StageCMulti: # ===== MMA warp ===== # One wait per kt; same slot index used for both K (QK) and V (PV). # Release happens AFTER PV — combined slot stays held across QK+PV. + # Note: dropped the try_wait/pk pattern here too, matching the TMA + # warp's simplification. Bare wait_and_advance, no loop-carried pk. if warp_idx == self.mma_warp_id: tmem.wait_for_alloc() qc.reset(); qh = qc.wait_and_advance(); qh.release() kvc.reset() - for kt in range(n_tiles): - kvh = kvc.wait_and_advance() acc_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage) + acc_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage) acc_pipe.producer_acquire(acc_st) for kt in range(n_kv_tiles): + kvh = kvc.wait_and_advance() sh = s_prod.acquire_and_advance() qk_mma.set(tcgen05.Field.ACCUMULATE, False) for kb in cutlass.range(cute.size(tCrQ, mode=[2]), unroll_full=True): @@ -252,12 +253,6 @@ class FmhaV3StageCMulti: cute.arch.fence_view_async_tmem_store() kvh.release() acc_pipe.producer_commit(acc_st); acc_st.advance() - # Signal softmax FIRST so it can run normalize + epilogue. Then - # wait for the epilogue's consumer-release in producer_tail. - # Reverse order deadlocks: producer_tail blocks waiting for - # consumer release; softmax blocks at final_o_bar waiting for - # MMA arrive; the epilogue (which does the release) is gated - # behind softmax's final_o_bar wait. Cycle. final_o_bar.arrive() acc_pipe.producer_tail(acc_st) @@ -289,43 +284,36 @@ class FmhaV3StageCMulti: tScP = cute.make_tensor(tScS.iterator, tScP_layout) tTMEM_STOREcP = thr_store.partition_S(tScP) - # O rescale setup: same correction_rescale pattern as final normalize. - # Uses paired Ld32x32bOp/St32x32bOp atoms with matching Repetition(16). + # === O rescale path setup (used per-tile AND for final normalize) === corr_tile_size = 16 - cO_corr = cute.make_identity_tensor((self.pv_mma_tiler[0], self.pv_mma_tiler[1])) - tOcO_corr = pv_thr.partition_C(cO_corr) + cO = cute.make_identity_tensor((self.pv_mma_tiler[0], self.pv_mma_tiler[1])) + tOcO = pv_thr.partition_C(cO) tOtO_i_layout = cute.composition(tOtO0.layout, cute.make_layout((128, corr_tile_size))) - tOcO_i_layout = cute.composition(tOcO_corr.layout, cute.make_layout((128, corr_tile_size))) + tOcO_i_layout = cute.composition(tOcO.layout, cute.make_layout((128, corr_tile_size))) tOtO_i = cute.make_tensor(tOtO0.iterator, tOtO_i_layout) - tOcO_i = cute.make_tensor(tOcO_corr.iterator, tOcO_i_layout) + tOcO_i = cute.make_tensor(tOcO.iterator, tOcO_i_layout) tmem_load_o_atom = cute.make_copy_atom( - tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(corr_tile_size)), self.acc_dtype) + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(corr_tile_size)), + self.acc_dtype, + ) tmem_store_o_atom = cute.make_copy_atom( - tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(corr_tile_size)), self.acc_dtype) + tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(corr_tile_size)), + self.acc_dtype, + ) tiled_tmem_load_o = tcgen05.make_tmem_copy(tmem_load_o_atom, tOtO_i) tiled_tmem_store_o = tcgen05.make_tmem_copy(tmem_store_o_atom, tOtO_i) thr_tmem_load_o = tiled_tmem_load_o.get_slice(sfw_idx) thr_tmem_store_o = tiled_tmem_store_o.get_slice(sfw_idx) - tTMEM_LOAD_OtO = thr_tmem_load_o.partition_S(tOtO_i) - tTMEM_LOAD_OcO = thr_tmem_load_o.partition_D(tOcO_i) - tTMEM_STORE_OtO = thr_tmem_store_o.partition_D(tOtO_i) - tTMrO = cute.make_rmem_tensor( - (tTMEM_LOAD_OcO.shape, 128 // corr_tile_size), self.acc_dtype) + tTMEM_LOADtO = thr_tmem_load_o.partition_S(tOtO_i) + tTMEM_LOADcO = thr_tmem_load_o.partition_D(tOcO_i) + tTMEM_STOREtO = thr_tmem_store_o.partition_D(tOtO_i) + n_corr_tiles = HEAD_DIM // corr_tile_size row_max = -Float32.inf row_sum = Float32(0.0) scale_log2 = Float32(self.scale_softmax_log2) - # Per-tile softmax loop. - # Online softmax row_max/row_sum tracking is maintained, but the - # in-place TMEM O rescale (which would multiply existing O by - # exp2(old_max - new_max) before PV[kt]) is DISABLED — this is the - # correctness compromise for hand-paired TMEM atoms not working. - # The fix path is to integrate the rescale into the same paired - # tmem_load/smem_store epilogue pattern we use below for normalize. - # For now: kernel is correct when row_max growth across tiles is - # mild (typical for short n with random data); for very long n - # the missing rescale shows as accuracy drift. + # Per-tile softmax loop with online rescale. for kt in range(n_kv_tiles): si_handle = s_cons.wait_and_advance() @@ -334,7 +322,7 @@ class FmhaV3StageCMulti: cute.copy(tiled_tmem_load, tTMEM_LOADtS, tTMEM_LOADrS) cute.arch.fence_view_async_tmem_load() - # Pass 1: update row_max (in log2-domain, fused with scale). + # Pass 1: update row_max in log2-domain. old_row_max = row_max frg_cnt = 4 frg_tile = cute.size(tTMEM_LOADrS) // frg_cnt @@ -347,8 +335,8 @@ class FmhaV3StageCMulti: if row_max == -cutlass.Float32.inf: row_max_safe = Float32(0.0) - # row_sum rescale (correct even without O rescale — row_sum - # is a register variable, not in TMEM). + # acc_scale = exp2(old_max - new_max). On first tile this is 0 + # (old_max = -inf), so row_sum stays 0 and rescale is skipped. # row_max is already in scaled domain, so no extra scale_log2. acc_scale_ = old_row_max - row_max_safe acc_scale = cute.math.exp2(acc_scale_, fastmath=True) @@ -356,30 +344,8 @@ class FmhaV3StageCMulti: acc_scale = Float32(0.0) row_sum *= acc_scale - # O rescale: multiply existing O by acc_scale = exp2(old_max - new_max) - # Uses the correction_rescale pattern (same paired atoms as final normalize). - # Must happen BEFORE softmax_done_bar.arrive() so MMA's PV[kt] sees rescaled O. - if kt > 0: - for ci in range(HEAD_DIM // corr_tile_size): - tTMrO_i_ = tTMrO[None, ci] - tTMrO_i_layout = cute.composition( - tTMrO_i_.layout, cute.make_layout(tTMrO.shape[0]) - ) - tTMrO_i = cute.make_tensor(tTMrO_i_.iterator, tTMrO_i_layout) - tTMEM_LOAD_OtO_i = cute.make_tensor( - tTMEM_LOAD_OtO.iterator + ci * corr_tile_size, tTMEM_LOAD_OtO.layout - ) - tTMEM_STORE_OtO_i = cute.make_tensor( - tTMEM_STORE_OtO.iterator + ci * corr_tile_size, tTMEM_STORE_OtO.layout - ) - cute.copy(tiled_tmem_load_o, tTMEM_LOAD_OtO_i, tTMrO_i) - for j in cutlass.range(cute.size(tTMrO_i), vectorize=True): - tTMrO_i[j] = tTMrO_i[j] * acc_scale - cute.copy(tiled_tmem_store_o, tTMrO_i, tTMEM_STORE_OtO_i) - cute.arch.fence_view_async_tmem_store() - # Pass 2: P = exp2((S - new_max) * log2), accumulate row_sum, - # store BF16 P through the FP32-backed register bridge. + # cast to BF16 via FP32-backed register bridge. rP_words = cute.make_rmem_tensor(tTMEM_STOREcP.shape, self.qk_acc_dtype) rP_bf16 = cute.make_tensor(cute.recast_ptr(rP_words.iterator, dtype=self.q_dtype), tTMEM_LOADrS.layout) minus_row_max = Float32(0.0) - row_max_safe @@ -396,86 +362,56 @@ class FmhaV3StageCMulti: cute.copy(tiled_tmem_store, rP_words, tTMEM_STOREtP) cute.arch.fence_view_async_tmem_store() + # === Per-tile O rescale: O *= acc_scale for kt > 0 === + # Uses the SAME paired-atom pattern as the final normalize. + # Must run BEFORE softmax_done_bar.arrive() so MMA's PV[kt] + # reads the rescaled O. + # Visibility of MMA's PV[kt-1] writes: provided by + # s_cons.wait_and_advance at the top of this iteration, which + # acquires on MMA's S[kt] commit. S[kt] is sequenced after + # PV[kt-1] in MMA's iteration, so PV[kt-1]'s tmem_store_fence + # has been observed by the time we read O here. + if kt > 0: + for i in range(n_corr_tiles): + tTMEM_LOADtO_i = cute.make_tensor( + tTMEM_LOADtO.iterator + i * corr_tile_size, + tTMEM_LOADtO.layout, + ) + tTMEM_STOREtO_i = cute.make_tensor( + tTMEM_STOREtO.iterator + i * corr_tile_size, + tTMEM_STOREtO.layout, + ) + tTMrO = cute.make_rmem_tensor(tTMEM_LOADcO.shape, self.acc_dtype) + cute.copy(tiled_tmem_load_o, tTMEM_LOADtO_i, tTMrO) + cute.arch.fence_view_async_tmem_load() + for k in cutlass.range(cute.size(tTMrO), vectorize=True): + tTMrO[k] = tTMrO[k] * acc_scale + cute.copy(tiled_tmem_store_o, tTMrO, tTMEM_STOREtO_i) + cute.arch.fence_view_async_tmem_store() + si_handle.release() softmax_done_bar.arrive() - # === Reference-style scaled epilogue (no TMEM round-trip) === - # - # Pattern (mirrors CUTLASS Blackwell FMHA reference's - # correction_epilog): for each column sub-tile, - # 1. TMEM -> registers via PAIRED tmem_load atom - # 2. scale in registers (1/row_sum) - # 3. FP32 -> BF16 conversion in registers - # 4. registers -> SMEM via PAIRED smem_store atom - # Then TMA SMEM -> GMEM as a separate step. - # - # Critical: the load and store atoms MUST be a matched pair. - # Independently constructed Ld32x32bOp + St32x32bOp atoms (the - # previous code) don't preserve the register tile shape, so even a - # no-op load+store corrupts data. Using utils.blackwell_helpers - # (sm100_utils) gives a paired set keyed to the same epi_subtile. - - # Wait for MMA's PV[N-1] to commit before reading O. + # Wait for MMA's PV[N-1] to commit before reading O for normalize. final_o_bar.arrive_and_wait() - # === O normalization via TMEM load → scale → TMEM store === - # Matches CUTLASS reference's correction_rescale pattern exactly. - - corr_tile_size = 16 - - cO = cute.make_identity_tensor((self.pv_mma_tiler[0], self.pv_mma_tiler[1])) - tOcO = pv_thr.partition_C(cO) - - tOtO_i_layout = cute.composition(tOtO0.layout, cute.make_layout((128, corr_tile_size))) - tOcO_i_layout = cute.composition(tOcO.layout, cute.make_layout((128, corr_tile_size))) - - tOtO_i = cute.make_tensor(tOtO0.iterator, tOtO_i_layout) - tOcO_i = cute.make_tensor(tOcO.iterator, tOcO_i_layout) - - tmem_load_atom = cute.make_copy_atom( - tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(corr_tile_size)), - self.acc_dtype, - ) - tmem_store_atom = cute.make_copy_atom( - tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(corr_tile_size)), - self.acc_dtype, - ) - - tiled_tmem_load_o = tcgen05.make_tmem_copy(tmem_load_atom, tOtO_i) - tiled_tmem_store_o = tcgen05.make_tmem_copy(tmem_store_atom, tOtO_i) - - thr_tmem_load_o = tiled_tmem_load_o.get_slice(sfw_idx) - thr_tmem_store_o = tiled_tmem_store_o.get_slice(sfw_idx) - - tTMEM_LOADtO = thr_tmem_load_o.partition_S(tOtO_i) - tTMEM_LOADcO = thr_tmem_load_o.partition_D(tOcO_i) - tTMEM_STOREtO = thr_tmem_store_o.partition_D(tOtO_i) - - # 2D register tensor: (frg_shape, n_corr_tiles) - tTMrO = cute.make_rmem_tensor( - (tTMEM_LOADcO.shape, 128 // corr_tile_size), self.acc_dtype - ) - + # === Final O normalization: O *= 1/row_sum === inv_row_sum = Float32(1.0) / row_sum - - for i in range(HEAD_DIM // corr_tile_size): - tTMrO_i_ = tTMrO[None, i] - tTMrO_i_layout = cute.composition( - tTMrO_i_.layout, cute.make_layout(tTMrO.shape[0]) - ) - tTMrO_i = cute.make_tensor(tTMrO_i_.iterator, tTMrO_i_layout) + for i in range(n_corr_tiles): tTMEM_LOADtO_i = cute.make_tensor( - tTMEM_LOADtO.iterator + i * corr_tile_size, tTMEM_LOADtO.layout + tTMEM_LOADtO.iterator + i * corr_tile_size, + tTMEM_LOADtO.layout, ) tTMEM_STOREtO_i = cute.make_tensor( - tTMEM_STOREtO.iterator + i * corr_tile_size, tTMEM_STOREtO.layout + tTMEM_STOREtO.iterator + i * corr_tile_size, + tTMEM_STOREtO.layout, ) - - cute.copy(tiled_tmem_load_o, tTMEM_LOADtO_i, tTMrO_i) - for j in cutlass.range(cute.size(tTMrO_i), vectorize=True): - tTMrO_i[j] = tTMrO_i[j] * inv_row_sum - cute.copy(tiled_tmem_store_o, tTMrO_i, tTMEM_STOREtO_i) - + tTMrO = cute.make_rmem_tensor(tTMEM_LOADcO.shape, self.acc_dtype) + cute.copy(tiled_tmem_load_o, tTMEM_LOADtO_i, tTMrO) + cute.arch.fence_view_async_tmem_load() + for k in cutlass.range(cute.size(tTMrO), vectorize=True): + tTMrO[k] = tTMrO[k] * inv_row_sum + cute.copy(tiled_tmem_store_o, tTMrO, tTMEM_STOREtO_i) cute.arch.fence_view_async_tmem_store() # Standard epilogue: TMEM → SMEM → GMEM via TMA store. @@ -521,7 +457,6 @@ def test(): mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c)) stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) - # Each n requires its own compiled kernel (s_k is compile-time). kernel = FmhaV3StageCMulti(s_k=n) print(f'n={n}: Compiling...', flush=True) compiled = cute.compile(kernel, mQ, mK, mV, mC, stream)