SMEM-P: implement two-phase softmax with normalization before SMEM write

This commit is contained in:
2026-05-23 20:08:29 +00:00
parent e4e63b0331
commit 0bcdfa1dc9

View File

@@ -348,12 +348,24 @@ class FmhaKernel:
minus_row_max = Float32(0.0) - row_max_safe
rP_bf16_frg = cute.logical_divide(rP_bf16, cute.make_layout(frg_tile))
# Phase 1: Compute exp values and accumulate row_sum
for j in range(frg_cnt):
for k in range(cute.size(tTMEM_LOADrS_frg, mode=[0])):
tTMEM_LOADrS_frg[k, j] = tTMEM_LOADrS_frg[k, j] * scale_log2 + minus_row_max
tTMEM_LOADrS_frg[k, j] = cute.math.exp2(tTMEM_LOADrS_frg[k, j], fastmath=True)
row_sum = row_sum + tTMEM_LOADrS_frg[k, j]
s_vec = tTMEM_LOADrS_frg[None, j].load()
rP_bf16_frg[None, j].store(s_vec.to(self.q_dtype))
# Compute inverse row sum for normalization
inv_row_sum = Float32(1.0) / row_sum
# Phase 2: Normalize P values and write to SMEM (if using SMEM-P)
for j in range(frg_cnt):
for k in range(cute.size(tTMEM_LOADrS_frg, mode=[0])):
# Get normalized P value
p_val = tTMEM_LOADrS_frg[k, j] * inv_row_sum
# If using SMEM-P, write P value directly to SMEM
if self.use_smem_p:
# Get QK coordinate for this position
qk_coord = tTMEM_LOADcS_frg[k, j]
@@ -372,11 +384,8 @@ class FmhaKernel:
n2 = n_local // 64
pv_coord = ((m_local, n0), 0, (n1, n2), 0)
# DEBUG: Write pattern based on fragment indices (k,j)
# If coordinates wrong, this pattern might work better
pattern_val = Float32(k) + Float32(j) * Float32(32.0)
p_val_bf16 = pattern_val.to(self.q_dtype)
# Original: p_val_bf16 = tTMEM_LOADrS_frg[k, j].to(self.q_dtype)
# Write normalized P value
p_val_bf16 = p_val.to(self.q_dtype)
sP[pv_coord] = p_val_bf16 # Tensor indexing
# DEBUG: Print first few coordinates to verify mapping
@@ -388,17 +397,9 @@ class FmhaKernel:
print(f"[SMEM-P DEBUG] offset = {offset}")
except:
print(f"[SMEM-P DEBUG] crd2idx not available")
# DEBUG: Also write pattern based on fragment indices (k,j)
# If coordinates wrong, this pattern might work better
pattern_val = Float32(k) + Float32(j) * Float32(32.0)
p_val_bf16 = pattern_val.to(self.q_dtype)
# Original: p_val_bf16 = tTMEM_LOADrS_frg[k, j].to(self.q_dtype)
sP[pv_coord] = p_val_bf16 # Tensor indexing
row_sum = row_sum + tTMEM_LOADrS_frg[k, j]
s_vec = tTMEM_LOADrS_frg[None, j].load()
rP_bf16_frg[None, j].store(s_vec.to(self.q_dtype))
else:
# For TMEM-P, store normalized P to register buffer
rP_bf16_frg[k, j] = p_val.to(self.q_dtype)
if not self.use_smem_p:
# TMEM-P: store P to TMEM via register bridge