From 902bda5c31495cb63b40342039d162263ee42b97 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sat, 23 May 2026 03:40:10 +0000 Subject: [PATCH] D1: align P store and PV A-fragment layouts via tP Key insight: tP (PV A-fragment base) used p_tmem_s.outer layout, but P store used QK C-fragment composition layout. These diverge at hd>64. Fix: tP now uses the same QK C-fragment composition layout (tStP_layout) as the P store. PV A-fragment is derived from tP, so it automatically uses the same layout. No double-offset since tP includes P offset. --- dsv4/kernels/attention/fmha.py | 42 +++++++++++++++++++--------------- 1 file changed, 24 insertions(+), 18 deletions(-) diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index 37a18301..fa0c83d6 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -157,12 +157,18 @@ class FmhaKernel: tOtO = pv_thr.make_fragment_C(pv_as) tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout) - tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer) + # P tensor for PV A-fragment: use QK C-fragment composition layout so P store and PV read agree + p_cols_fp32 = self.pv_mma_tiler[2] * self.q_dtype.width // self.qk_acc_dtype.width + tStP_layout = cute.composition(tStS.layout, cute.make_layout((self.pv_mma_tiler[0], p_cols_fp32))) + # PV A-fragment offset: p0_offset in BF16 columns = p0_offset * (FP32_width / BF16_width) + tP = cute.make_tensor( + tStS.iterator + self.tmem_p0_offset * (self.qk_acc_dtype.width // self.q_dtype.width), + tStP_layout, + ) tOrP_base = pv_thr.make_fragment_A(tP) tOrP = tOrP_base[(None,None,None,0)] - tOrP0 = cute.make_tensor( - tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset, - tOrP.layout) + # tP already starts at P offset, no additional offset needed + tOrP0 = tOrP tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_as, self.num_acc_stage)) pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk) @@ -225,17 +231,15 @@ class FmhaKernel: # P store atoms p_cols_fp32 = self.pv_mma_tiler[2] * self.q_dtype.width // self.qk_acc_dtype.width - # P store: use PV A-fragment layout (tOrP0) so PV GEMM reads correct TMEM columns. - # Store as BF16 (matching PV A-fragment), not FP32 (which causes layout mismatch at hd>64). - tStP0_bf16 = cute.make_tensor(tOrP0.iterator, tOrP0.layout) - tmem_store_atom = cute.make_copy_atom(tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.q_dtype) - tiled_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStP0_bf16) + # P store uses QK C-fragment composition layout (same layout as tP/PV A-fragment) + tStP0 = cute.make_tensor(tStS.iterator + self.tmem_p0_offset, tStP_layout) + tmem_store_atom = cute.make_copy_atom(tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype) + tiled_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStP0) thr_store = tiled_tmem_store.get_slice(sfw_idx) - tTMEM_STOREtP = thr_store.partition_D(tStP0_bf16) - # Coordinate tensor for P store - cP = cute.make_identity_tensor((self.pv_mma_tiler[0], p_cols_fp32)) - tOcP = pv_thr.partition_A(cP) - tTMEM_STOREcP = thr_store.partition_S(cute.make_tensor(tOcP.iterator, tOrP0.layout)) + tTMEM_STOREtP = thr_store.partition_D(tStP0) + tScP_layout = cute.composition(tScS.layout, cute.make_layout((self.pv_mma_tiler[0], p_cols_fp32))) + tScP = cute.make_tensor(tScS.iterator, tScP_layout) + tTMEM_STOREcP = thr_store.partition_S(tScP) row_max = -Float32.inf row_sum = Float32(0.0) @@ -290,10 +294,12 @@ class FmhaKernel: acc_scale = Float32(0.0) row_sum *= acc_scale - # Store P to TMEM using PV A-fragment layout (BF16) - rP_bf16_reg = cute.make_rmem_tensor(tTMEM_STOREcP.shape, self.q_dtype) + # Store P to TMEM (FP32, using QK C-fragment composition layout) + rP_words = cute.make_rmem_tensor(tTMEM_STOREcP.shape, self.qk_acc_dtype) + rP_bf16 = cute.make_tensor(cute.recast_ptr(rP_words.iterator, dtype=self.q_dtype), tTMEM_LOADrS.layout) minus_row_max = Float32(0.0) - row_max_safe - rP_bf16_frg = cute.logical_divide(rP_bf16_reg, cute.make_layout(frg_tile)) + + rP_bf16_frg = cute.logical_divide(rP_bf16, cute.make_layout(frg_tile)) for j in range(frg_cnt): for k in range(cute.size(tTMEM_LOADrS_frg, mode=[0])): tTMEM_LOADrS_frg[k, j] = tTMEM_LOADrS_frg[k, j] * scale_log2 + minus_row_max @@ -302,7 +308,7 @@ class FmhaKernel: s_vec = tTMEM_LOADrS_frg[None, j].load() rP_bf16_frg[None, j].store(s_vec.to(self.q_dtype)) - cute.copy(tiled_tmem_store, rP_bf16_reg, tTMEM_STOREtP) + cute.copy(tiled_tmem_store, rP_words, tTMEM_STOREtP) cute.arch.fence_view_async_tmem_store() # Per-tile O rescale (hand-constructed atoms with logical_divide layout)