fix vectorize issue: remove vectorize from exp2 pass, add row_sum accumulation
- Remove vectorize=True from exp2 computation loop (carry variable) - Add row_sum accumulation from P values in exp2 pass - Compute row_max via fmax in separate pass
This commit is contained in:
@@ -247,23 +247,18 @@ class FmhaV3StageC:
|
||||
cute.copy(tiled_tmem_load, tTMEM_LOADtS, tTMEM_LOADrS)
|
||||
cute.arch.fence_view_async_tmem_load()
|
||||
|
||||
# Compute row_max and row_sum via element-wise processing
|
||||
old_row_max = row_max
|
||||
|
||||
# P = exp2((S - new_max) * scale_log2) via register bridge
|
||||
# First: find row_max from S (element-wise, before exp2)
|
||||
# Compute row_max from S (element-wise, before exp2)
|
||||
frg_cnt = 4
|
||||
frg_tile = cute.size(tTMEM_LOADrS) // frg_cnt
|
||||
tTMEM_LOADrS_frg = cute.logical_divide(tTMEM_LOADrS, cute.make_layout(frg_tile))
|
||||
for k in cutlass.range(cute.size(tTMEM_LOADrS_frg, mode=[0]), vectorize=True):
|
||||
for j in range(frg_cnt):
|
||||
s_val = tTMEM_LOADrS_frg[k, j] * scale_log2
|
||||
row_max = cute.arch.fmax(row_max, s_val)
|
||||
for j in range(frg_cnt):
|
||||
for k in cutlass.range(cute.size(tTMEM_LOADrS_frg, mode=[0]), vectorize=True):
|
||||
row_max = cute.arch.fmax(row_max, tTMEM_LOADrS_frg[k, j] * scale_log2)
|
||||
|
||||
row_max_safe = row_max
|
||||
if row_max == -cutlass.Float32.inf: row_max_safe = Float32(0.0)
|
||||
|
||||
# Scale existing row_sum: row_sum *= exp2((old_max - new_max) * scale_log2)
|
||||
# Scale existing row_sum
|
||||
acc_scale_ = scale_log2 * (old_row_max - row_max_safe)
|
||||
acc_scale = cute.math.exp2(acc_scale_, fastmath=True)
|
||||
if old_row_max == -cutlass.Float32.inf: acc_scale = Float32(0.0)
|
||||
@@ -276,7 +271,7 @@ class FmhaV3StageC:
|
||||
|
||||
rP_bf16_frg = cute.logical_divide(rP_bf16, cute.make_layout(frg_tile))
|
||||
for j in range(frg_cnt):
|
||||
for k in cutlass.range(cute.size(tTMEM_LOADrS_frg, mode=[0]), vectorize=True):
|
||||
for k in range(cute.size(tTMEM_LOADrS_frg, mode=[0])):
|
||||
tTMEM_LOADrS_frg[k, j] = tTMEM_LOADrS_frg[k, j] * scale_log2 + minus_row_max_scale
|
||||
tTMEM_LOADrS_frg[k, j] = cute.math.exp2(tTMEM_LOADrS_frg[k, j], fastmath=True)
|
||||
row_sum = row_sum + tTMEM_LOADrS_frg[k, j]
|
||||
|
||||
Reference in New Issue
Block a user