softmax: element-wise row_max computation instead of .reduce()

The .reduce() on the C-fragment gives global max across all rows,
not per-row max. Compute row_max element-wise from S values before
the exp2 pass. Also accumulate row_sum in the exp2 pass.
This commit is contained in:
2026-05-22 09:27:36 +00:00
parent 5e51b726ba
commit 941bcae8e1

View File

@@ -247,9 +247,19 @@ class FmhaV3StageC:
cute.copy(tiled_tmem_load, tTMEM_LOADtS, tTMEM_LOADrS)
cute.arch.fence_view_async_tmem_load()
# Update row_max
# Compute row_max and row_sum via element-wise processing
old_row_max = row_max
row_max = tTMEM_LOADrS.load().reduce(cute.ReductionOp.MAX, row_max, 0)
# P = exp2((S - new_max) * scale_log2) via register bridge
# First: find row_max from S (element-wise, before exp2)
frg_cnt = 4
frg_tile = cute.size(tTMEM_LOADrS) // frg_cnt
tTMEM_LOADrS_frg = cute.logical_divide(tTMEM_LOADrS, cute.make_layout(frg_tile))
for k in cutlass.range(cute.size(tTMEM_LOADrS_frg, mode=[0]), vectorize=True):
for j in range(frg_cnt):
s_val = tTMEM_LOADrS_frg[k, j] * scale_log2
if s_val > row_max: row_max = s_val
row_max_safe = row_max
if row_max == -cutlass.Float32.inf: row_max_safe = Float32(0.0)
@@ -259,24 +269,19 @@ class FmhaV3StageC:
if old_row_max == -cutlass.Float32.inf: acc_scale = Float32(0.0)
row_sum *= acc_scale
# P = exp2((S - new_max) * scale_log2) via register bridge
# Second pass: compute P and accumulate row_sum
rP_words = cute.make_rmem_tensor(tTMEM_STOREcP.shape, self.qk_acc_dtype)
rP_bf16 = cute.make_tensor(cute.recast_ptr(rP_words.iterator, dtype=self.q_dtype), tTMEM_LOADrS.layout)
minus_row_max_scale = (Float32(0.0) - row_max_safe) * scale_log2
frg_cnt = 4
frg_tile = cute.size(tTMEM_LOADrS) // frg_cnt
tTMEM_LOADrS_frg = cute.logical_divide(tTMEM_LOADrS, cute.make_layout(frg_tile))
rP_bf16_frg = cute.logical_divide(rP_bf16, cute.make_layout(frg_tile))
for j in range(frg_cnt):
for k in cutlass.range(cute.size(tTMEM_LOADrS_frg, mode=[0]), vectorize=True):
tTMEM_LOADrS_frg[k, j] = tTMEM_LOADrS_frg[k, j] * scale_log2 + minus_row_max_scale
tTMEM_LOADrS_frg[k, j] = cute.math.exp2(tTMEM_LOADrS_frg[k, j], fastmath=True)
row_sum = row_sum + tTMEM_LOADrS_frg[k, j]
s_vec = tTMEM_LOADrS_frg[None, j].load()
rP_bf16_frg[None, j].store(s_vec.to(self.q_dtype))
# Accumulate row_sum from P values
for k in cutlass.range(cute.size(tTMEM_LOADrS_frg, mode=[0])):
row_sum = row_sum + tTMEM_LOADrS_frg[k, j]
cute.copy(tiled_tmem_store, rP_words, tTMEM_STOREtP)
cute.arch.fence_view_async_tmem_store()