fix vectorize issue: remove vectorize from exp2 pass, add row_sum accumulation

- Remove vectorize=True from exp2 computation loop (carry variable)
- Add row_sum accumulation from P values in exp2 pass
- Compute row_max via fmax in separate pass
This commit is contained in:
2026-05-22 09:29:43 +00:00
parent f1687ba3b8
commit ce06478e56

View File

@@ -247,23 +247,18 @@ class FmhaV3StageC:
cute.copy(tiled_tmem_load, tTMEM_LOADtS, tTMEM_LOADrS)
cute.arch.fence_view_async_tmem_load()
# Compute row_max and row_sum via element-wise processing
old_row_max = row_max
# P = exp2((S - new_max) * scale_log2) via register bridge
# First: find row_max from S (element-wise, before exp2)
# Compute row_max from S (element-wise, before exp2)
frg_cnt = 4
frg_tile = cute.size(tTMEM_LOADrS) // frg_cnt
tTMEM_LOADrS_frg = cute.logical_divide(tTMEM_LOADrS, cute.make_layout(frg_tile))
for k in cutlass.range(cute.size(tTMEM_LOADrS_frg, mode=[0]), vectorize=True):
for j in range(frg_cnt):
s_val = tTMEM_LOADrS_frg[k, j] * scale_log2
row_max = cute.arch.fmax(row_max, s_val)
for j in range(frg_cnt):
for k in cutlass.range(cute.size(tTMEM_LOADrS_frg, mode=[0]), vectorize=True):
row_max = cute.arch.fmax(row_max, tTMEM_LOADrS_frg[k, j] * scale_log2)
row_max_safe = row_max
if row_max == -cutlass.Float32.inf: row_max_safe = Float32(0.0)
# Scale existing row_sum: row_sum *= exp2((old_max - new_max) * scale_log2)
# Scale existing row_sum
acc_scale_ = scale_log2 * (old_row_max - row_max_safe)
acc_scale = cute.math.exp2(acc_scale_, fastmath=True)
if old_row_max == -cutlass.Float32.inf: acc_scale = Float32(0.0)
@@ -276,7 +271,7 @@ class FmhaV3StageC:
rP_bf16_frg = cute.logical_divide(rP_bf16, cute.make_layout(frg_tile))
for j in range(frg_cnt):
for k in cutlass.range(cute.size(tTMEM_LOADrS_frg, mode=[0]), vectorize=True):
for k in range(cute.size(tTMEM_LOADrS_frg, mode=[0])):
tTMEM_LOADrS_frg[k, j] = tTMEM_LOADrS_frg[k, j] * scale_log2 + minus_row_max_scale
tTMEM_LOADrS_frg[k, j] = cute.math.exp2(tTMEM_LOADrS_frg[k, j], fastmath=True)
row_sum = row_sum + tTMEM_LOADrS_frg[k, j]