CRITICAL FIX: remove extra scale_log2 in softmax (minus_row_max and acc_scale)

This commit is contained in:
2026-05-22 18:52:58 +00:00
parent fe58d07b8c
commit 4fef047f5c

View File

@@ -183,7 +183,7 @@ class FmhaV3StageCMulti:
b_lay = cute.make_layout(cute.slice_(cl_vmnk,(0,None,0,0)).shape)
tBsK,tBgK = cpasync.tma_partition(tma_k,0,b_lay,cute.group_modes(sK,0,3),cute.group_modes(tCgK,0,3))
tVsV,tVgV = cpasync.tma_partition(tma_v,0,b_lay,cute.group_modes(sV,0,3),cute.group_modes(tCgV,0,3))
tAgQ = tAgQ[(None,0,None,0)]; tBgK = tBgK[(None,None,0,0)]; tVgV = tVgV[(None,None,0,0)]
tAgQ = tAgQ[(None,0,None,0)]; tBgK = tBgK[(None,None,0,0)]; tVgV = tVgV[(None,0,None,0)]
tCrQ = qk_mma.make_fragment_A(sQ); tCrK = qk_mma.make_fragment_B(sK)
tCrV = pv_mma.make_fragment_B(sV)
@@ -207,20 +207,20 @@ class FmhaV3StageCMulti:
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
# ===== TMA LOAD warp =====
# GMEM tile coordinate: use Python range() so the JIT traces each
# iteration separately with concrete kt values. cutlass.range generates
# an scf.for where the induction variable gets constant-folded into
# the TMA descriptor (always 0 at runtime). Plain range() unrolls at
# trace time, giving each iteration a distinct static coordinate.
# Combined K+V barrier pattern, matching the working test_fmha_v3_diag.py.
# K uses (None,None,0,0) pre-slice to keep GMEM tile dim free.
# V uses (None,0,None,0) — GMEM tile dim accessible via kv_coord.
if warp_idx == self.tma_warp_id:
qp.reset(); qh = qp.acquire_and_advance()
cute.copy(tma_q, tAgQ[(None, Int32(0))], tAsQ[(None, qh.index)], tma_bar_ptr=qh.barrier)
qp.tail()
kvp.reset(); pk = kvp.try_acquire()
for kt in range(n_kv_tiles):
kv_coord = Int32(0 + 0)
for kt in cutlass.range(n_kv_tiles, unroll=1):
kvh = kvp.acquire_and_advance(pk)
cute.copy(tma_k, tBgK[(None, Int32(kt))], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
cute.copy(tma_v, tVgV[(None, Int32(kt))], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
cute.copy(tma_k, tBgK[(None, kv_coord)], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
cute.copy(tma_v, tVgV[(None, kv_coord)], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
kv_coord += 1
pk = cutlass.Boolean(1)
kvp.tail()
@@ -336,7 +336,9 @@ class FmhaV3StageCMulti:
row_max_safe = Float32(0.0)
# acc_scale is used for both the row_sum rescale and the O rescale.
acc_scale_ = scale_log2 * (old_row_max - row_max_safe)
# row_max is already in the scaled domain (S * scale_log2), so
# acc_scale = exp2(old_max - new_max) needs no extra scale_log2 factor.
acc_scale_ = old_row_max - row_max_safe
acc_scale = cute.math.exp2(acc_scale_, fastmath=True)
if old_row_max == -cutlass.Float32.inf:
acc_scale = Float32(0.0)
@@ -346,12 +348,12 @@ class FmhaV3StageCMulti:
# store BF16 P through the FP32-backed register bridge.
rP_words = cute.make_rmem_tensor(tTMEM_STOREcP.shape, self.qk_acc_dtype)
rP_bf16 = cute.make_tensor(cute.recast_ptr(rP_words.iterator, dtype=self.q_dtype), tTMEM_LOADrS.layout)
minus_row_max_scale = (Float32(0.0) - row_max_safe) * scale_log2
minus_row_max = Float32(0.0) - row_max_safe
rP_bf16_frg = cute.logical_divide(rP_bf16, cute.make_layout(frg_tile))
for j in range(frg_cnt):
for k in range(cute.size(tTMEM_LOADrS_frg, mode=[0])):
tTMEM_LOADrS_frg[k, j] = tTMEM_LOADrS_frg[k, j] * scale_log2 + minus_row_max_scale
tTMEM_LOADrS_frg[k, j] = tTMEM_LOADrS_frg[k, j] * scale_log2 + minus_row_max
tTMEM_LOADrS_frg[k, j] = cute.math.exp2(tTMEM_LOADrS_frg[k, j], fastmath=True)
row_sum = row_sum + tTMEM_LOADrS_frg[k, j]
s_vec = tTMEM_LOADrS_frg[None, j].load()