From 0bcdfa1dc9247bd4395631bbdc45710d254b020f Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sat, 23 May 2026 20:08:29 +0000 Subject: [PATCH] SMEM-P: implement two-phase softmax with normalization before SMEM write --- dsv4/kernels/attention/fmha.py | 35 +++++++++++++++++----------------- 1 file changed, 18 insertions(+), 17 deletions(-) diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index 87783e49..4afcae64 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -348,12 +348,24 @@ class FmhaKernel: minus_row_max = Float32(0.0) - row_max_safe rP_bf16_frg = cute.logical_divide(rP_bf16, cute.make_layout(frg_tile)) + # Phase 1: Compute exp values and accumulate row_sum for j in range(frg_cnt): for k in range(cute.size(tTMEM_LOADrS_frg, mode=[0])): tTMEM_LOADrS_frg[k, j] = tTMEM_LOADrS_frg[k, j] * scale_log2 + minus_row_max tTMEM_LOADrS_frg[k, j] = cute.math.exp2(tTMEM_LOADrS_frg[k, j], fastmath=True) + row_sum = row_sum + tTMEM_LOADrS_frg[k, j] + s_vec = tTMEM_LOADrS_frg[None, j].load() + rP_bf16_frg[None, j].store(s_vec.to(self.q_dtype)) + + # Compute inverse row sum for normalization + inv_row_sum = Float32(1.0) / row_sum + + # Phase 2: Normalize P values and write to SMEM (if using SMEM-P) + for j in range(frg_cnt): + for k in range(cute.size(tTMEM_LOADrS_frg, mode=[0])): + # Get normalized P value + p_val = tTMEM_LOADrS_frg[k, j] * inv_row_sum - # If using SMEM-P, write P value directly to SMEM if self.use_smem_p: # Get QK coordinate for this position qk_coord = tTMEM_LOADcS_frg[k, j] @@ -372,11 +384,8 @@ class FmhaKernel: n2 = n_local // 64 pv_coord = ((m_local, n0), 0, (n1, n2), 0) - # DEBUG: Write pattern based on fragment indices (k,j) - # If coordinates wrong, this pattern might work better - pattern_val = Float32(k) + Float32(j) * Float32(32.0) - p_val_bf16 = pattern_val.to(self.q_dtype) - # Original: p_val_bf16 = tTMEM_LOADrS_frg[k, j].to(self.q_dtype) + # Write normalized P value + p_val_bf16 = p_val.to(self.q_dtype) sP[pv_coord] = p_val_bf16 # Tensor indexing # DEBUG: Print first few coordinates to verify mapping @@ -388,17 +397,9 @@ class FmhaKernel: print(f"[SMEM-P DEBUG] offset = {offset}") except: print(f"[SMEM-P DEBUG] crd2idx not available") - - # DEBUG: Also write pattern based on fragment indices (k,j) - # If coordinates wrong, this pattern might work better - pattern_val = Float32(k) + Float32(j) * Float32(32.0) - p_val_bf16 = pattern_val.to(self.q_dtype) - # Original: p_val_bf16 = tTMEM_LOADrS_frg[k, j].to(self.q_dtype) - sP[pv_coord] = p_val_bf16 # Tensor indexing - - row_sum = row_sum + tTMEM_LOADrS_frg[k, j] - s_vec = tTMEM_LOADrS_frg[None, j].load() - rP_bf16_frg[None, j].store(s_vec.to(self.q_dtype)) + else: + # For TMEM-P, store normalized P to register buffer + rP_bf16_frg[k, j] = p_val.to(self.q_dtype) if not self.use_smem_p: # TMEM-P: store P to TMEM via register bridge