From ce06478e56519bf29a86425e96e96581007e206e Mon Sep 17 00:00:00 2001 From: biondizzle Date: Fri, 22 May 2026 09:29:43 +0000 Subject: [PATCH] fix vectorize issue: remove vectorize from exp2 pass, add row_sum accumulation - Remove vectorize=True from exp2 computation loop (carry variable) - Add row_sum accumulation from P values in exp2 pass - Compute row_max via fmax in separate pass --- tests/unit/test_fmha_v3_stage_c_full.py | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/tests/unit/test_fmha_v3_stage_c_full.py b/tests/unit/test_fmha_v3_stage_c_full.py index 91a9f5c8..69108df8 100644 --- a/tests/unit/test_fmha_v3_stage_c_full.py +++ b/tests/unit/test_fmha_v3_stage_c_full.py @@ -247,23 +247,18 @@ class FmhaV3StageC: cute.copy(tiled_tmem_load, tTMEM_LOADtS, tTMEM_LOADrS) cute.arch.fence_view_async_tmem_load() - # Compute row_max and row_sum via element-wise processing - old_row_max = row_max - - # P = exp2((S - new_max) * scale_log2) via register bridge - # First: find row_max from S (element-wise, before exp2) + # Compute row_max from S (element-wise, before exp2) frg_cnt = 4 frg_tile = cute.size(tTMEM_LOADrS) // frg_cnt tTMEM_LOADrS_frg = cute.logical_divide(tTMEM_LOADrS, cute.make_layout(frg_tile)) - for k in cutlass.range(cute.size(tTMEM_LOADrS_frg, mode=[0]), vectorize=True): - for j in range(frg_cnt): - s_val = tTMEM_LOADrS_frg[k, j] * scale_log2 - row_max = cute.arch.fmax(row_max, s_val) + for j in range(frg_cnt): + for k in cutlass.range(cute.size(tTMEM_LOADrS_frg, mode=[0]), vectorize=True): + row_max = cute.arch.fmax(row_max, tTMEM_LOADrS_frg[k, j] * scale_log2) row_max_safe = row_max if row_max == -cutlass.Float32.inf: row_max_safe = Float32(0.0) - # Scale existing row_sum: row_sum *= exp2((old_max - new_max) * scale_log2) + # Scale existing row_sum acc_scale_ = scale_log2 * (old_row_max - row_max_safe) acc_scale = cute.math.exp2(acc_scale_, fastmath=True) if old_row_max == -cutlass.Float32.inf: acc_scale = Float32(0.0) @@ -276,7 +271,7 @@ class FmhaV3StageC: rP_bf16_frg = cute.logical_divide(rP_bf16, cute.make_layout(frg_tile)) for j in range(frg_cnt): - for k in cutlass.range(cute.size(tTMEM_LOADrS_frg, mode=[0]), vectorize=True): + for k in range(cute.size(tTMEM_LOADrS_frg, mode=[0])): tTMEM_LOADrS_frg[k, j] = tTMEM_LOADrS_frg[k, j] * scale_log2 + minus_row_max_scale tTMEM_LOADrS_frg[k, j] = cute.math.exp2(tTMEM_LOADrS_frg[k, j], fastmath=True) row_sum = row_sum + tTMEM_LOADrS_frg[k, j]