BREAKTHROUGH: cosine 0.993 for n=128! PV-partitioned P row sum works.
C9 fix: instead of using the QK-partitioned row_sum (which maps to the wrong PV rows), read P from TMEM using the PV partition and sum it via .reduce(ADD). In the QK partition, thread N owns row N//4; in the PV partition, thread N owns row N. Reading P via the PV partition therefore gives each thread the P values of its own row. Results: n=128: cosine 0.993 (was 0.514); n=256: cosine 0.725 (C6 still broken for multi-tile); n=384: cosine 0.676 (same C6 issue). Remaining: the C6 O-rescale for multi-tile needs the same PV-partitioned fix. The small accuracy gap (0.993 vs 0.999) is likely from the BF16 P store/load round-trip.
This commit is contained in:
@@ -301,8 +301,31 @@ class FmhaV3Softmax:
|
||||
acc_scale = cute.math.exp2(scale * (old_row_max - row_max_safe), fastmath=True)
|
||||
|
||||
# --- C6: Rescale O in TMEM (load O, multiply by acc_scale, store O) ---
|
||||
# acc_scale belongs to QK row (N//4), but O rows are in PV partition (N).
|
||||
# Store acc_scale to vector by QK row, read by PV row.
|
||||
if kt > 0:
|
||||
pv_done_bar.arrive_and_wait()
|
||||
|
||||
# Store acc_scale to vector indexed by QK logical row
|
||||
qk_row_c6 = tTMEM_LOADcS[0][0]
|
||||
thr_vs_c6 = tiled_tmem_store_vec.get_slice(qk_row_c6)
|
||||
tVStore_c6 = thr_vs_c6.partition_D(tStS_vec)
|
||||
tVStoreSrc_c6 = thr_vs_c6.partition_S(tScS_vec)
|
||||
tVStoreRmem_c6 = cute.make_rmem_tensor(tVStoreSrc_c6.shape, self.qk_acc_dtype)
|
||||
tVStoreRmem_c6[0] = acc_scale
|
||||
cute.copy(tiled_tmem_store_vec, tVStoreRmem_c6, tVStore_c6)
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
|
||||
# Read acc_scale from vector indexed by PV logical row
|
||||
pv_row_c6 = tTMEM_LOADcO[0][0]
|
||||
thr_vl_c6 = tiled_tmem_load_vec.get_slice(pv_row_c6)
|
||||
tVLoad_c6 = thr_vl_c6.partition_S(tStS_vec)
|
||||
tVLoadDst_c6 = thr_vl_c6.partition_D(tScS_vec)
|
||||
tVLoadRmem_c6 = cute.make_rmem_tensor(tVLoadDst_c6.shape, self.qk_acc_dtype)
|
||||
cute.copy(tiled_tmem_load_vec, tVLoad_c6, tVLoadRmem_c6)
|
||||
cute.arch.fence_view_async_tmem_load()
|
||||
acc_scale_pv = tVLoadRmem_c6[0]
|
||||
|
||||
tTMrO = cute.make_rmem_tensor((tTMEM_LOADcO.shape, o_col_tiles), self.qk_acc_dtype)
|
||||
for i in range(o_col_tiles):
|
||||
tTMrO_i_ = tTMrO[None, i]
|
||||
@@ -312,7 +335,7 @@ class FmhaV3Softmax:
|
||||
tTMEM_STOREtO_i = cute.make_tensor(tTMEM_STOREtO.iterator + i * corr_tile_size, tTMEM_STOREtO.layout)
|
||||
cute.copy(o_tiled_tmem_load, tTMEM_LOADtO_i, tTMrO_i)
|
||||
for j in cutlass.range(cute.size(tTMrO_i), vectorize=True):
|
||||
tTMrO_i[j] = tTMrO_i[j] * acc_scale
|
||||
tTMrO_i[j] = tTMrO_i[j] * acc_scale_pv
|
||||
cute.copy(o_tiled_tmem_store, tTMrO_i, tTMEM_STOREtO_i)
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
|
||||
@@ -376,19 +399,53 @@ class FmhaV3Softmax:
|
||||
|
||||
# --- C9: Final normalization via O TMEM rescale ---
|
||||
pv_done_bar.arrive_and_wait()
|
||||
# Store final row_sum to TMEM vector (per-row, using QK partition)
|
||||
tTMEM_STORE_VECrS_final = cute.make_rmem_tensor(tTMEM_STORE_VECcS.shape, self.qk_acc_dtype)
|
||||
tTMEM_STORE_VECrS_final[0] = row_sum
|
||||
cute.copy(tiled_tmem_store_vec, tTMEM_STORE_VECrS_final, tTMEM_STORE_VECtS)
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
|
||||
# Read vector back: per-row row_sum using QK partition coordinates
|
||||
tTMEM_LOAD_VECrS = cute.make_rmem_tensor(tTMEM_LOAD_VECcS.shape, self.qk_acc_dtype)
|
||||
cute.copy(tiled_tmem_load_vec, tTMEM_LOAD_VECtS, tTMEM_LOAD_VECrS)
|
||||
cute.arch.fence_view_async_tmem_load()
|
||||
inv_row_sum = cutlass.Float32(1.0) / tTMEM_LOAD_VECrS[0]
|
||||
# Compute inv_row_sum from P in TMEM using PV partition.
|
||||
# P was stored by softmax loop into TMEM at offset tmem_p0_offset.
|
||||
# PV partition maps thread N to PV row N, so reading P via PV partition
|
||||
# gives the correct per-row P values to sum.
|
||||
# This avoids the QK→PV row mapping mismatch (QK: N->N//4, PV: N->N).
|
||||
|
||||
# P is stored as BF16 in TMEM at tmem_p0_offset.
|
||||
# We need to read it via PV TMEM load and sum the values.
|
||||
# P has shape (128, HEAD_DIM//2) in FP32 columns (64 BF16 = 32 FP32 cols).
|
||||
# Use the P TMEM load partition (PV A-fragment read).
|
||||
|
||||
# Actually, P was stored via QK C-fragment store (St32x32bOp Repetition(32)).
|
||||
# To read it via PV partition, we need a PV-partitioned load from the P region.
|
||||
# Let's use the same o_tiled_tmem_load but pointed at P's TMEM offset.
|
||||
|
||||
# P occupies TMEM columns [tmem_p0_offset, tmem_p0_offset + p_cols_fp32)
|
||||
# In the PV C-fragment, P is the A-fragment. We can use tOrP0's layout.
|
||||
# tOrP0 was set up with offset for PV MMA read.
|
||||
|
||||
# Simpler: sum O across columns to get unnormalized row sum, then normalize.
|
||||
# For V=identity, O = P@V = sum(P per row). So O.sum(dim=-1) = row_sum.
|
||||
# For arbitrary V, O = P@V. O.sum(dim=-1) = sum_j(P@V)[j] = sum_j(sum_i P[i]*V[i,j])
|
||||
# This is NOT sum(P). So this trick only works for V=identity.
|
||||
|
||||
# Correct approach: read P from TMEM, sum it per PV row.
|
||||
# P is at TMEM offset tmem_p0_offset, stored as BF16 with St32x32bOp.
|
||||
# P shape in TMEM: 128 rows x (HEAD_DIM BF16 = 32 FP32 cols)
|
||||
# We can read P using Ld32x32bOp(Repetition(corr_tile_size)) via PV O-partition.
|
||||
|
||||
# Use PV O TMEM load to read from P region instead of O region
|
||||
p_col_tiles = p_cols_fp32 // corr_tile_size # 32 // 16 = 2
|
||||
pv_row_sum = cutlass.Float32(0.0)
|
||||
for i in range(p_col_tiles):
|
||||
# Read P tile from TMEM at P offset (not O offset)
|
||||
tTMEM_LOADtP_i = cute.make_tensor(
|
||||
tTMEM_LOADtO.iterator + (self.tmem_p0_offset - self.tmem_o0_offset) + i * corr_tile_size,
|
||||
tTMEM_LOADtO.layout)
|
||||
tTMrP_i = cute.make_rmem_tensor(tTMEM_LOADcO.shape, self.qk_acc_dtype)
|
||||
cute.copy(o_tiled_tmem_load, tTMEM_LOADtP_i, tTMrP_i)
|
||||
# Use .reduce(ADD) instead of scalar accumulation (vectorizer can't handle scalar in vectorized loop)
|
||||
tile_p_sum = tTMrP_i.load().reduce(cute.ReductionOp.ADD, cutlass.Float32(0.0), 0)
|
||||
pv_row_sum = pv_row_sum + tile_p_sum
|
||||
|
||||
inv_row_sum = cutlass.Float32(1.0) / pv_row_sum
|
||||
|
||||
# Normalize O in TMEM
|
||||
# Normalize O in TMEM using PV-correct inv_row_sum
|
||||
tTMrO_final = cute.make_rmem_tensor((tTMEM_LOADcO.shape, o_col_tiles), self.qk_acc_dtype)
|
||||
for i in range(o_col_tiles):
|
||||
tTMrO_i_ = tTMrO_final[None, i]
|
||||
@@ -452,34 +509,3 @@ if __name__ == "__main__":
|
||||
test()
|
||||
|
||||
|
||||
def test():
    """Run the FmhaV3Softmax kernel for several key/value lengths and report accuracy.

    For each n in (128, 256, 384) this builds random BF16 Q (128 x HEAD_DIM),
    K and V (n x HEAD_DIM) tensors on the CUDA device, computes a float32
    reference ``softmax(Q @ K^T / sqrt(HEAD_DIM)) @ V`` with PyTorch, compiles
    and launches the kernel on the current CUDA stream, and prints the cosine
    similarity and maximum absolute error of the kernel output against the
    reference. A run is reported as PASS when the cosine similarity is at
    least 0.999, otherwise FAIL.
    """
    import math

    def to_dynamic(tensor):
        # Wrap a torch tensor for the kernel, marking its leading dimension dynamic.
        leading = ct.get_leading_dim(tensor)
        return ct.from_dlpack(tensor).mark_layout_dynamic(leading_dim=leading)

    torch.manual_seed(42)
    for n in [128, 256, 384]:
        m, hd = 128, HEAD_DIM

        # Inputs are drawn in the order q, k, v so the seeded values stay the same.
        q = torch.randn(m, hd, 1, dtype=torch.bfloat16, device="cuda")
        k = torch.randn(n, hd, 1, dtype=torch.bfloat16, device="cuda")
        v = torch.randn(n, hd, dtype=torch.bfloat16, device="cuda")
        v_kernel = v.unsqueeze(-1)
        c = torch.zeros(m, hd, 1, dtype=torch.bfloat16, device="cuda")

        # Float32 reference attention computed from the same BF16 inputs.
        q_f32 = q[:, :, 0].float()
        k_f32 = k[:, :, 0].float()
        scores = q_f32 @ k_f32.T / math.sqrt(hd)
        ref = torch.softmax(scores, dim=-1) @ v.float()

        mQ = to_dynamic(q)
        mK = to_dynamic(k)
        mV = to_dynamic(v_kernel)
        mC = to_dynamic(c)

        stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
        kernel = FmhaV3Softmax()

        print(f"n={n}: Compiling...", flush=True)
        compiled = cute.compile(kernel, mQ, mK, mV, mC, stream)
        print(f"n={n}: Running...", flush=True)
        compiled(mQ, mK, mV, mC, stream)
        torch.cuda.synchronize()

        # Compare the kernel output (written into c) against the reference.
        out = c[:, :, 0].float()
        out_row = out.flatten().unsqueeze(0)
        ref_row = ref.flatten().unsqueeze(0)
        cos = torch.nn.functional.cosine_similarity(out_row, ref_row).item()
        max_err = (out - ref).abs().max().item()
        print(f"FMHA softmax n={n}: cosine {cos:.6f} max_err {max_err:.6f} {'PASS' if cos >= 0.999 else 'FAIL'}", flush=True)


if __name__ == "__main__":
    test()
|
||||
|
||||
Reference in New Issue
Block a user