softmax: element-wise row_max computation instead of .reduce()
The .reduce() on the C-fragment gives global max across all rows, not per-row max. Compute row_max element-wise from S values before the exp2 pass. Also accumulate row_sum in the exp2 pass.
This commit is contained in:
@@ -247,9 +247,19 @@ class FmhaV3StageC:
|
||||
cute.copy(tiled_tmem_load, tTMEM_LOADtS, tTMEM_LOADrS)
|
||||
cute.arch.fence_view_async_tmem_load()
|
||||
|
||||
# Update row_max
|
||||
# Compute row_max and row_sum via element-wise processing
|
||||
old_row_max = row_max
|
||||
row_max = tTMEM_LOADrS.load().reduce(cute.ReductionOp.MAX, row_max, 0)
|
||||
|
||||
# P = exp2((S - new_max) * scale_log2) via register bridge
|
||||
# First: find row_max from S (element-wise, before exp2)
|
||||
frg_cnt = 4
|
||||
frg_tile = cute.size(tTMEM_LOADrS) // frg_cnt
|
||||
tTMEM_LOADrS_frg = cute.logical_divide(tTMEM_LOADrS, cute.make_layout(frg_tile))
|
||||
for k in cutlass.range(cute.size(tTMEM_LOADrS_frg, mode=[0]), vectorize=True):
|
||||
for j in range(frg_cnt):
|
||||
s_val = tTMEM_LOADrS_frg[k, j] * scale_log2
|
||||
if s_val > row_max: row_max = s_val
|
||||
|
||||
row_max_safe = row_max
|
||||
if row_max == -cutlass.Float32.inf: row_max_safe = Float32(0.0)
|
||||
|
||||
@@ -259,24 +269,19 @@ class FmhaV3StageC:
|
||||
if old_row_max == -cutlass.Float32.inf: acc_scale = Float32(0.0)
|
||||
row_sum *= acc_scale
|
||||
|
||||
# P = exp2((S - new_max) * scale_log2) via register bridge
|
||||
# Second pass: compute P and accumulate row_sum
|
||||
rP_words = cute.make_rmem_tensor(tTMEM_STOREcP.shape, self.qk_acc_dtype)
|
||||
rP_bf16 = cute.make_tensor(cute.recast_ptr(rP_words.iterator, dtype=self.q_dtype), tTMEM_LOADrS.layout)
|
||||
minus_row_max_scale = (Float32(0.0) - row_max_safe) * scale_log2
|
||||
|
||||
frg_cnt = 4
|
||||
frg_tile = cute.size(tTMEM_LOADrS) // frg_cnt
|
||||
tTMEM_LOADrS_frg = cute.logical_divide(tTMEM_LOADrS, cute.make_layout(frg_tile))
|
||||
rP_bf16_frg = cute.logical_divide(rP_bf16, cute.make_layout(frg_tile))
|
||||
for j in range(frg_cnt):
|
||||
for k in cutlass.range(cute.size(tTMEM_LOADrS_frg, mode=[0]), vectorize=True):
|
||||
tTMEM_LOADrS_frg[k, j] = tTMEM_LOADrS_frg[k, j] * scale_log2 + minus_row_max_scale
|
||||
tTMEM_LOADrS_frg[k, j] = cute.math.exp2(tTMEM_LOADrS_frg[k, j], fastmath=True)
|
||||
row_sum = row_sum + tTMEM_LOADrS_frg[k, j]
|
||||
s_vec = tTMEM_LOADrS_frg[None, j].load()
|
||||
rP_bf16_frg[None, j].store(s_vec.to(self.q_dtype))
|
||||
# Accumulate row_sum from P values
|
||||
for k in cutlass.range(cute.size(tTMEM_LOADrS_frg, mode=[0])):
|
||||
row_sum = row_sum + tTMEM_LOADrS_frg[k, j]
|
||||
|
||||
cute.copy(tiled_tmem_store, rP_words, tTMEM_STOREtP)
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
|
||||
Reference in New Issue
Block a user