CRITICAL FIX: remove extra scale_log2 in softmax (minus_row_max and acc_scale)
This commit is contained in:
@@ -183,7 +183,7 @@ class FmhaV3StageCMulti:
|
||||
b_lay = cute.make_layout(cute.slice_(cl_vmnk,(0,None,0,0)).shape)
|
||||
tBsK,tBgK = cpasync.tma_partition(tma_k,0,b_lay,cute.group_modes(sK,0,3),cute.group_modes(tCgK,0,3))
|
||||
tVsV,tVgV = cpasync.tma_partition(tma_v,0,b_lay,cute.group_modes(sV,0,3),cute.group_modes(tCgV,0,3))
|
||||
tAgQ = tAgQ[(None,0,None,0)]; tBgK = tBgK[(None,None,0,0)]; tVgV = tVgV[(None,None,0,0)]
|
||||
tAgQ = tAgQ[(None,0,None,0)]; tBgK = tBgK[(None,None,0,0)]; tVgV = tVgV[(None,0,None,0)]
|
||||
|
||||
tCrQ = qk_mma.make_fragment_A(sQ); tCrK = qk_mma.make_fragment_B(sK)
|
||||
tCrV = pv_mma.make_fragment_B(sV)
|
||||
@@ -207,20 +207,20 @@ class FmhaV3StageCMulti:
|
||||
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
|
||||
|
||||
# ===== TMA LOAD warp =====
|
||||
# GMEM tile coordinate: use Python range() so the JIT traces each
|
||||
# iteration separately with concrete kt values. cutlass.range generates
|
||||
# an scf.for where the induction variable gets constant-folded into
|
||||
# the TMA descriptor (always 0 at runtime). Plain range() unrolls at
|
||||
# trace time, giving each iteration a distinct static coordinate.
|
||||
# Combined K+V barrier pattern, matching the one in the working test_fmha_v3_diag.py.
|
||||
# K uses (None,None,0,0) pre-slice to keep GMEM tile dim free.
|
||||
# V uses (None,0,None,0) — GMEM tile dim accessible via kv_coord.
|
||||
if warp_idx == self.tma_warp_id:
|
||||
qp.reset(); qh = qp.acquire_and_advance()
|
||||
cute.copy(tma_q, tAgQ[(None, Int32(0))], tAsQ[(None, qh.index)], tma_bar_ptr=qh.barrier)
|
||||
qp.tail()
|
||||
kvp.reset(); pk = kvp.try_acquire()
|
||||
for kt in range(n_kv_tiles):
|
||||
kv_coord = Int32(0 + 0)
|
||||
for kt in cutlass.range(n_kv_tiles, unroll=1):
|
||||
kvh = kvp.acquire_and_advance(pk)
|
||||
cute.copy(tma_k, tBgK[(None, Int32(kt))], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
|
||||
cute.copy(tma_v, tVgV[(None, Int32(kt))], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
|
||||
cute.copy(tma_k, tBgK[(None, kv_coord)], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
|
||||
cute.copy(tma_v, tVgV[(None, kv_coord)], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
|
||||
kv_coord += 1
|
||||
pk = cutlass.Boolean(1)
|
||||
kvp.tail()
|
||||
|
||||
@@ -336,7 +336,9 @@ class FmhaV3StageCMulti:
|
||||
row_max_safe = Float32(0.0)
|
||||
|
||||
# acc_scale used for both row_sum rescale and O rescale.
|
||||
acc_scale_ = scale_log2 * (old_row_max - row_max_safe)
|
||||
# row_max is already in the scaled domain (S * scale_log2), so
|
||||
# acc_scale = exp2(old_max - new_max) with no extra scale_log2.
|
||||
acc_scale_ = old_row_max - row_max_safe
|
||||
acc_scale = cute.math.exp2(acc_scale_, fastmath=True)
|
||||
if old_row_max == -cutlass.Float32.inf:
|
||||
acc_scale = Float32(0.0)
|
||||
@@ -346,12 +348,12 @@ class FmhaV3StageCMulti:
|
||||
# store BF16 P through the FP32-backed register bridge.
|
||||
rP_words = cute.make_rmem_tensor(tTMEM_STOREcP.shape, self.qk_acc_dtype)
|
||||
rP_bf16 = cute.make_tensor(cute.recast_ptr(rP_words.iterator, dtype=self.q_dtype), tTMEM_LOADrS.layout)
|
||||
minus_row_max_scale = (Float32(0.0) - row_max_safe) * scale_log2
|
||||
minus_row_max = Float32(0.0) - row_max_safe
|
||||
|
||||
rP_bf16_frg = cute.logical_divide(rP_bf16, cute.make_layout(frg_tile))
|
||||
for j in range(frg_cnt):
|
||||
for k in range(cute.size(tTMEM_LOADrS_frg, mode=[0])):
|
||||
tTMEM_LOADrS_frg[k, j] = tTMEM_LOADrS_frg[k, j] * scale_log2 + minus_row_max_scale
|
||||
tTMEM_LOADrS_frg[k, j] = tTMEM_LOADrS_frg[k, j] * scale_log2 + minus_row_max
|
||||
tTMEM_LOADrS_frg[k, j] = cute.math.exp2(tTMEM_LOADrS_frg[k, j], fastmath=True)
|
||||
row_sum = row_sum + tTMEM_LOADrS_frg[k, j]
|
||||
s_vec = tTMEM_LOADrS_frg[None, j].load()
|
||||
|
||||
Reference in New Issue
Block a user