From 941bcae8e170afdf70dbc1e8d3eafd4334e205d9 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Fri, 22 May 2026 09:27:36 +0000 Subject: [PATCH] softmax: element-wise row_max computation instead of .reduce() The .reduce() on the C-fragment gives global max across all rows, not per-row max. Compute row_max element-wise from S values before the exp2 pass. Also accumulate row_sum in the exp2 pass. --- tests/unit/test_fmha_v3_stage_c_full.py | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/tests/unit/test_fmha_v3_stage_c_full.py b/tests/unit/test_fmha_v3_stage_c_full.py index 209053f6..6f086071 100644 --- a/tests/unit/test_fmha_v3_stage_c_full.py +++ b/tests/unit/test_fmha_v3_stage_c_full.py @@ -247,9 +247,19 @@ class FmhaV3StageC: cute.copy(tiled_tmem_load, tTMEM_LOADtS, tTMEM_LOADrS) cute.arch.fence_view_async_tmem_load() - # Update row_max + # Compute row_max and row_sum via element-wise processing old_row_max = row_max - row_max = tTMEM_LOADrS.load().reduce(cute.ReductionOp.MAX, row_max, 0) + + # P = exp2((S - new_max) * scale_log2) via register bridge + # First: find row_max from S (element-wise, before exp2) + frg_cnt = 4 + frg_tile = cute.size(tTMEM_LOADrS) // frg_cnt + tTMEM_LOADrS_frg = cute.logical_divide(tTMEM_LOADrS, cute.make_layout(frg_tile)) + for k in cutlass.range(cute.size(tTMEM_LOADrS_frg, mode=[0]), vectorize=True): + for j in range(frg_cnt): + s_val = tTMEM_LOADrS_frg[k, j] * scale_log2 + if s_val > row_max: row_max = s_val + row_max_safe = row_max if row_max == -cutlass.Float32.inf: row_max_safe = Float32(0.0) @@ -259,24 +269,19 @@ class FmhaV3StageC: if old_row_max == -cutlass.Float32.inf: acc_scale = Float32(0.0) row_sum *= acc_scale - # P = exp2((S - new_max) * scale_log2) via register bridge + # Second pass: compute P and accumulate row_sum rP_words = cute.make_rmem_tensor(tTMEM_STOREcP.shape, self.qk_acc_dtype) rP_bf16 = cute.make_tensor(cute.recast_ptr(rP_words.iterator, dtype=self.q_dtype), tTMEM_LOADrS.layout) minus_row_max_scale = (Float32(0.0) - row_max_safe) * scale_log2 - frg_cnt = 4 - frg_tile = cute.size(tTMEM_LOADrS) // frg_cnt - tTMEM_LOADrS_frg = cute.logical_divide(tTMEM_LOADrS, cute.make_layout(frg_tile)) rP_bf16_frg = cute.logical_divide(rP_bf16, cute.make_layout(frg_tile)) for j in range(frg_cnt): for k in cutlass.range(cute.size(tTMEM_LOADrS_frg, mode=[0]), vectorize=True): tTMEM_LOADrS_frg[k, j] = tTMEM_LOADrS_frg[k, j] * scale_log2 + minus_row_max_scale tTMEM_LOADrS_frg[k, j] = cute.math.exp2(tTMEM_LOADrS_frg[k, j], fastmath=True) + row_sum = row_sum + tTMEM_LOADrS_frg[k, j] s_vec = tTMEM_LOADrS_frg[None, j].load() rP_bf16_frg[None, j].store(s_vec.to(self.q_dtype)) - # Accumulate row_sum from P values - for k in cutlass.range(cute.size(tTMEM_LOADrS_frg, mode=[0])): - row_sum = row_sum + tTMEM_LOADrS_frg[k, j] cute.copy(tiled_tmem_store, rP_words, tTMEM_STOREtP) cute.arch.fence_view_async_tmem_store()