SMEM-P: implement two-phase softmax with normalization before SMEM write
This commit is contained in:
@@ -348,12 +348,24 @@ class FmhaKernel:
|
||||
minus_row_max = Float32(0.0) - row_max_safe
|
||||
|
||||
rP_bf16_frg = cute.logical_divide(rP_bf16, cute.make_layout(frg_tile))
|
||||
# Phase 1: Compute exp values and accumulate row_sum
|
||||
for j in range(frg_cnt):
|
||||
for k in range(cute.size(tTMEM_LOADrS_frg, mode=[0])):
|
||||
tTMEM_LOADrS_frg[k, j] = tTMEM_LOADrS_frg[k, j] * scale_log2 + minus_row_max
|
||||
tTMEM_LOADrS_frg[k, j] = cute.math.exp2(tTMEM_LOADrS_frg[k, j], fastmath=True)
|
||||
row_sum = row_sum + tTMEM_LOADrS_frg[k, j]
|
||||
s_vec = tTMEM_LOADrS_frg[None, j].load()
|
||||
rP_bf16_frg[None, j].store(s_vec.to(self.q_dtype))
|
||||
|
||||
# Compute inverse row sum for normalization
|
||||
inv_row_sum = Float32(1.0) / row_sum
|
||||
|
||||
# Phase 2: Normalize P values and write to SMEM (if using SMEM-P)
|
||||
for j in range(frg_cnt):
|
||||
for k in range(cute.size(tTMEM_LOADrS_frg, mode=[0])):
|
||||
# Get normalized P value
|
||||
p_val = tTMEM_LOADrS_frg[k, j] * inv_row_sum
|
||||
|
||||
# If using SMEM-P, write P value directly to SMEM
|
||||
if self.use_smem_p:
|
||||
# Get QK coordinate for this position
|
||||
qk_coord = tTMEM_LOADcS_frg[k, j]
|
||||
@@ -372,11 +384,8 @@ class FmhaKernel:
|
||||
n2 = n_local // 64
|
||||
pv_coord = ((m_local, n0), 0, (n1, n2), 0)
|
||||
|
||||
# DEBUG: Write pattern based on fragment indices (k,j)
|
||||
# If coordinates wrong, this pattern might work better
|
||||
pattern_val = Float32(k) + Float32(j) * Float32(32.0)
|
||||
p_val_bf16 = pattern_val.to(self.q_dtype)
|
||||
# Original: p_val_bf16 = tTMEM_LOADrS_frg[k, j].to(self.q_dtype)
|
||||
# Write normalized P value
|
||||
p_val_bf16 = p_val.to(self.q_dtype)
|
||||
sP[pv_coord] = p_val_bf16 # Tensor indexing
|
||||
|
||||
# DEBUG: Print first few coordinates to verify mapping
|
||||
@@ -388,17 +397,9 @@ class FmhaKernel:
|
||||
print(f"[SMEM-P DEBUG] offset = {offset}")
|
||||
except:
|
||||
print(f"[SMEM-P DEBUG] crd2idx not available")
|
||||
|
||||
# DEBUG: Also write pattern based on fragment indices (k,j)
|
||||
# If coordinates wrong, this pattern might work better
|
||||
pattern_val = Float32(k) + Float32(j) * Float32(32.0)
|
||||
p_val_bf16 = pattern_val.to(self.q_dtype)
|
||||
# Original: p_val_bf16 = tTMEM_LOADrS_frg[k, j].to(self.q_dtype)
|
||||
sP[pv_coord] = p_val_bf16 # Tensor indexing
|
||||
|
||||
row_sum = row_sum + tTMEM_LOADrS_frg[k, j]
|
||||
s_vec = tTMEM_LOADrS_frg[None, j].load()
|
||||
rP_bf16_frg[None, j].store(s_vec.to(self.q_dtype))
|
||||
else:
|
||||
# For TMEM-P, store normalized P to register buffer
|
||||
rP_bf16_frg[k, j] = p_val.to(self.q_dtype)
|
||||
|
||||
if not self.use_smem_p:
|
||||
# TMEM-P: store P to TMEM via register bridge
|
||||
|
||||
Reference in New Issue
Block a user