diff --git a/tests/fmha_v3_stage_c_example6.py b/tests/fmha_v3_stage_c_example6.py index ec9db606..b56738af 100644 --- a/tests/fmha_v3_stage_c_example6.py +++ b/tests/fmha_v3_stage_c_example6.py @@ -183,7 +183,7 @@ class FmhaV3StageCMulti: b_lay = cute.make_layout(cute.slice_(cl_vmnk,(0,None,0,0)).shape) tBsK,tBgK = cpasync.tma_partition(tma_k,0,b_lay,cute.group_modes(sK,0,3),cute.group_modes(tCgK,0,3)) tVsV,tVgV = cpasync.tma_partition(tma_v,0,b_lay,cute.group_modes(sV,0,3),cute.group_modes(tCgV,0,3)) - tAgQ = tAgQ[(None,0,None,0)]; tBgK = tBgK[(None,None,0,0)]; tVgV = tVgV[(None,None,0,0)] + tAgQ = tAgQ[(None,0,None,0)]; tBgK = tBgK[(None,None,0,0)]; tVgV = tVgV[(None,0,None,0)] tCrQ = qk_mma.make_fragment_A(sQ); tCrK = qk_mma.make_fragment_B(sK) tCrV = pv_mma.make_fragment_B(sV) @@ -207,20 +207,20 @@ class FmhaV3StageCMulti: pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk) # ===== TMA LOAD warp ===== - # GMEM tile coordinate: use Python range() so the JIT traces each - # iteration separately with concrete kt values. cutlass.range generates - # an scf.for where the induction variable gets constant-folded into - # the TMA descriptor (always 0 at runtime). Plain range() unrolls at - # trace time, giving each iteration a distinct static coordinate. + # Combined K+V barrier pattern matching working test_fmha_v3_diag.py. + # K uses (None,None,0,0) pre-slice to keep GMEM tile dim free. + # V uses (None,0,None,0) — GMEM tile dim accessible via kv_coord. if warp_idx == self.tma_warp_id: qp.reset(); qh = qp.acquire_and_advance() cute.copy(tma_q, tAgQ[(None, Int32(0))], tAsQ[(None, qh.index)], tma_bar_ptr=qh.barrier) qp.tail() kvp.reset(); pk = kvp.try_acquire() - for kt in range(n_kv_tiles): + kv_coord = Int32(0 + 0) + for kt in cutlass.range(n_kv_tiles, unroll=1): kvh = kvp.acquire_and_advance(pk) - cute.copy(tma_k, tBgK[(None, Int32(kt))], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier) - cute.copy(tma_v, tVgV[(None, Int32(kt))], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier) + cute.copy(tma_k, tBgK[(None, kv_coord)], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier) + cute.copy(tma_v, tVgV[(None, kv_coord)], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier) + kv_coord += 1 pk = cutlass.Boolean(1) kvp.tail() @@ -336,7 +336,9 @@ class FmhaV3StageCMulti: row_max_safe = Float32(0.0) # acc_scale used for both row_sum rescale and O rescale. - acc_scale_ = scale_log2 * (old_row_max - row_max_safe) + # row_max is already in scaled domain (S * scale_log2), so + # acc_scale = exp2(old_max - new_max) with no extra scale_log2. + acc_scale_ = old_row_max - row_max_safe acc_scale = cute.math.exp2(acc_scale_, fastmath=True) if old_row_max == -cutlass.Float32.inf: acc_scale = Float32(0.0) @@ -346,12 +348,12 @@ class FmhaV3StageCMulti: # store BF16 P through the FP32-backed register bridge. rP_words = cute.make_rmem_tensor(tTMEM_STOREcP.shape, self.qk_acc_dtype) rP_bf16 = cute.make_tensor(cute.recast_ptr(rP_words.iterator, dtype=self.q_dtype), tTMEM_LOADrS.layout) - minus_row_max_scale = (Float32(0.0) - row_max_safe) * scale_log2 + minus_row_max = Float32(0.0) - row_max_safe rP_bf16_frg = cute.logical_divide(rP_bf16, cute.make_layout(frg_tile)) for j in range(frg_cnt): for k in range(cute.size(tTMEM_LOADrS_frg, mode=[0])): - tTMEM_LOADrS_frg[k, j] = tTMEM_LOADrS_frg[k, j] * scale_log2 + minus_row_max_scale + tTMEM_LOADrS_frg[k, j] = tTMEM_LOADrS_frg[k, j] * scale_log2 + minus_row_max tTMEM_LOADrS_frg[k, j] = cute.math.exp2(tTMEM_LOADrS_frg[k, j], fastmath=True) row_sum = row_sum + tTMEM_LOADrS_frg[k, j] s_vec = tTMEM_LOADrS_frg[None, j].load()