From e33f5824369d3ed859e353a13df7592edf2bb467 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sat, 23 May 2026 03:36:40 +0000 Subject: [PATCH] D1: P store uses tOrP0.layout (PV A-fragment TMEM layout) --- dsv4/kernels/attention/fmha.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index b10a4c1e..bfa4ad0c 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -225,15 +225,17 @@ class FmhaKernel: # P store atoms p_cols_fp32 = self.pv_mma_tiler[2] * self.q_dtype.width // self.qk_acc_dtype.width - # P store must use the PV A-fragment layout (p_tmem_s.outer), not the QK C-fragment layout. - # At hd=64 these match by coincidence; at hd>64 they diverge, causing garbage PV output. - tStP0 = cute.make_tensor(tStS.iterator + self.tmem_p0_offset, p_tmem_s.outer) + # P store must use the PV A-fragment layout so PV reads correct TMEM columns. + # Use tOrP0's layout (PV A-fragment) for the store target, not QK C-fragment composition. + tStP0 = cute.make_tensor(tOrP0.iterator, tOrP0.layout) tmem_store_atom = cute.make_copy_atom(tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype) tiled_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStP0) thr_store = tiled_tmem_store.get_slice(sfw_idx) tTMEM_STOREtP = thr_store.partition_D(tStP0) - tScP = cute.make_tensor(tScS.iterator, p_tmem_s.outer) - tTMEM_STOREcP = thr_store.partition_S(tScP) + # tScP: coordinate tensor for P store — use PV A-fragment layout + cP = cute.make_identity_tensor((self.pv_mma_tiler[0], p_cols_fp32)) + tOcP = pv_thr.partition_A(cP) # partition using PV thread slice + tTMEM_STOREcP = thr_store.partition_S(cute.make_tensor(tOcP.iterator, tOrP0.layout)) row_max = -Float32.inf row_sum = Float32(0.0)