From a983a8fb41d6c5215cd2b2dbae91cfa9f81eaee2 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Thu, 21 May 2026 18:45:30 +0000 Subject: [PATCH] WIP: TMEM vector for per-row row_sum (not yet working) Key finding: the root cause is that each epilogue thread owns MULTIPLE rows in the QK C-fragment, so scalar row_max/row_sum are wrong (global across all rows, not per-row). The V=ones diagnostic confirmed: all 128 threads use the same row_sum (from row 114). Tried: TMEM vector store+load of row_sum (composition(tStS, (128,2))). This is a no-op because both write and read use the SAME QK partition with a scalar row_sum. The vector approach only helps when different partitions are used for write vs read, or when per-row values are stored. Next steps: 1. Need PER-ROW row_max and row_sum, not per-thread scalar 2. The CUTLASS FMHA works because each thread owns exactly 1 row 3. Options: restructure thread layout, or compute per-row values differently 4. The vector must store ALL 128 per-row values, then read per-row in C9 --- tests/unit/test_fmha_v3_softmax.py | 90 ++++++++++++++++++++---------- 1 file changed, 62 insertions(+), 28 deletions(-) diff --git a/tests/unit/test_fmha_v3_softmax.py b/tests/unit/test_fmha_v3_softmax.py index 80637ee4..ef9029bf 100644 --- a/tests/unit/test_fmha_v3_softmax.py +++ b/tests/unit/test_fmha_v3_softmax.py @@ -53,6 +53,7 @@ class FmhaV3Softmax: s_cols = self.qk_mma_tiler[1] # 128 o_after = max(s_cols, p_end) # 128 self.tmem_o0_offset = ((o_after + 31) // 32) * 32 # align to 32 = 128 + self.tmem_vec_offset = 0 # Reuse S region (free after softmax loop) o_cols = find_tmem_tensor_col_offset(tOtO) # footprint of O total = self.tmem_o0_offset + o_cols # Must be multiple of 32 AND power of 2 @@ -70,8 +71,8 @@ class FmhaV3Softmax: self.q_dtype = q.element_type; self.o_dtype = c.element_type; self.c_dtype = self.o_dtype self.a_major = LayoutEnum.from_tensor(q).mma_major_mode() self.b_major = LayoutEnum.from_tensor(k).mma_major_mode() - # s_k = cute.size(v, mode=[0]) - # FMHA-style V: reconstruct as (HEAD_DIM, 128, 1) MN-major + # s_k = cute.size(v, mode=[0]) # BROKEN in @cute.jit + # FMHA-style V: reconstruct as (HEAD_DIM, s_k, 1) MN-major v_fmha = cute.make_tensor( v.iterator, cute.make_layout( @@ -232,6 +233,26 @@ class FmhaV3Softmax: tScP = cute.make_tensor(tScS.iterator, tScP_layout) tTMEM_STOREcP = thr_store.partition_S(tScP) + # --- Vector TMEM (per-row row_sum storage, FMHA pattern) --- + # composition(tStS.layout, (128, 2)) = 2 FP32 columns per logical row + # vec[0] = row_sum (final, after loop), vec[1] = unused + # Reuses S TMEM region (offset 0), free after softmax loop writes + + tStS_vec_layout = cute.composition(tStS.layout, cute.make_layout((128, 2))) + tStS_vec = cute.make_tensor(tStS.iterator + self.tmem_vec_offset, tStS_vec_layout) + tScS_vec_layout = cute.composition(tScS.layout, cute.make_layout((128, 2))) + tScS_vec = cute.make_tensor(tScS.iterator, tScS_vec_layout) + tmem_store_vec_atom = cute.make_copy_atom(tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(2)), self.qk_acc_dtype) + tiled_tmem_store_vec = tcgen05.make_tmem_copy(tmem_store_vec_atom, tStS_vec) + thr_tmem_store_vec = tiled_tmem_store_vec.get_slice(sfw_idx) + tTMEM_STORE_VECtS = thr_tmem_store_vec.partition_D(tStS_vec) + tTMEM_STORE_VECcS = thr_tmem_store_vec.partition_S(tScS_vec) + tmem_load_vec_atom = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(2)), self.qk_acc_dtype) + tiled_tmem_load_vec = tcgen05.make_tmem_copy(tmem_load_vec_atom, tStS_vec) + thr_tmem_load_vec = tiled_tmem_load_vec.get_slice(sfw_idx) + tTMEM_LOAD_VECtS = thr_tmem_load_vec.partition_S(tStS_vec) + tTMEM_LOAD_VECcS = thr_tmem_load_vec.partition_D(tScS_vec) + # --- C6: O TMEM load/store for rescale (correction_rescale pattern) --- corr_tile_size = 16 cO = cute.make_identity_tensor((self.pv_mma_tiler[0], self.pv_mma_tiler[1])) @@ -354,8 +375,19 @@ class FmhaV3Softmax: # --- C9: Final normalization via O TMEM rescale --- pv_done_bar.arrive_and_wait() - inv_row_sum = cutlass.Float32(1.0) / row_sum + # Store final row_sum to TMEM vector (per-row, using QK partition) + tTMEM_STORE_VECrS_final = cute.make_rmem_tensor(tTMEM_STORE_VECcS.shape, self.qk_acc_dtype) + tTMEM_STORE_VECrS_final[0] = row_sum + cute.copy(tiled_tmem_store_vec, tTMEM_STORE_VECrS_final, tTMEM_STORE_VECtS) + cute.arch.fence_view_async_tmem_store() + # Read vector back: per-row row_sum using QK partition coordinates + tTMEM_LOAD_VECrS = cute.make_rmem_tensor(tTMEM_LOAD_VECcS.shape, self.qk_acc_dtype) + cute.copy(tiled_tmem_load_vec, tTMEM_LOAD_VECtS, tTMEM_LOAD_VECrS) + cute.arch.fence_view_async_tmem_load() + inv_row_sum = cutlass.Float32(1.0) / tTMEM_LOAD_VECrS[0] + + # Normalize O in TMEM tTMrO_final = cute.make_rmem_tensor((tTMEM_LOADcO.shape, o_col_tiles), self.qk_acc_dtype) for i in range(o_col_tiles): tTMrO_i_ = tTMrO_final[None, i] @@ -384,34 +416,36 @@ class FmhaV3Softmax: tmem.relinquish_alloc_permit() tmem.free(tmem_ptr) + def test(): import math torch.manual_seed(42) - n = 128; m = 128; hd = HEAD_DIM - q = torch.randn(m, hd, 1, dtype=torch.bfloat16, device="cuda") - k = torch.randn(n, hd, 1, dtype=torch.bfloat16, device="cuda") - v = torch.randn(n, hd, dtype=torch.bfloat16, device="cuda") - v_kernel = v.unsqueeze(-1) - c = torch.zeros(m, hd, 1, dtype=torch.bfloat16, device="cuda") - qf = q[:,:,0].float(); kf = k[:,:,0].float() - attn = qf @ kf.T / math.sqrt(hd) - ref = torch.softmax(attn, dim=-1) @ v.float() - mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q)) - mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k)) - mV = ct.from_dlpack(v_kernel).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_kernel)) - mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c)) - stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) - kernel = FmhaV3Softmax() - print("Compiling...", flush=True) - compiled = cute.compile(kernel, mQ, mK, mV, mC, stream) - print("Running n=128...", flush=True) - compiled(mQ, mK, mV, mC, stream) - torch.cuda.synchronize() - out = c[:,:,0].float() - cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item() - max_err = (out - ref).abs().max().item() - print(f"n=128: cosine {cos:.6f} max_err {max_err:.6f}") - print(f"out[0,:4]={out[0,:4].tolist()} ref[0,:4]={ref[0,:4].tolist()}") + for n in [128, 256, 384]: + m, hd = 128, HEAD_DIM + q = torch.randn(m, hd, 1, dtype=torch.bfloat16, device="cuda") + k = torch.randn(n, hd, 1, dtype=torch.bfloat16, device="cuda") + v = torch.randn(n, hd, dtype=torch.bfloat16, device="cuda") + v_kernel = v.unsqueeze(-1) + c = torch.zeros(m, hd, 1, dtype=torch.bfloat16, device="cuda") + qf = q[:,:,0].float(); kf = k[:,:,0].float() + attn = qf @ kf.T / math.sqrt(hd) + ref = torch.softmax(attn, dim=-1) @ v.float() + mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q)) + mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k)) + mV = ct.from_dlpack(v_kernel).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_kernel)) + mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c)) + stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + kernel = FmhaV3Softmax() + print(f"n={n}: Compiling...", flush=True) + compiled = cute.compile(kernel, mQ, mK, mV, mC, stream) + print(f"n={n}: tmem: s0={kernel.tmem_s0_offset} p0={kernel.tmem_p0_offset} o0={kernel.tmem_o0_offset} vec={kernel.tmem_vec_offset} alloc={kernel.num_tmem_alloc_cols}", flush=True) + print(f"n={n}: Running...", flush=True) + compiled(mQ, mK, mV, mC, stream) + torch.cuda.synchronize() + out = c[:,:,0].float() + cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item() + max_err = (out - ref).abs().max().item() + print(f"FMHA softmax n={n}: cosine {cos:.6f} max_err {max_err:.6f} {'PASS' if cos >= 0.999 else 'FAIL'}", flush=True) if __name__ == "__main__": test()