diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index bfa4ad0c..37a18301 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -225,16 +225,16 @@ class FmhaKernel: # P store atoms p_cols_fp32 = self.pv_mma_tiler[2] * self.q_dtype.width // self.qk_acc_dtype.width - # P store must use the PV A-fragment layout so PV reads correct TMEM columns. - # Use tOrP0's layout (PV A-fragment) for the store target, not QK C-fragment composition. - tStP0 = cute.make_tensor(tOrP0.iterator, tOrP0.layout) - tmem_store_atom = cute.make_copy_atom(tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype) - tiled_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStP0) + # P store: use PV A-fragment layout (tOrP0) so PV GEMM reads correct TMEM columns. + # Store as BF16 (matching PV A-fragment), not FP32 (which causes layout mismatch at hd>64). + tStP0_bf16 = cute.make_tensor(tOrP0.iterator, tOrP0.layout) + tmem_store_atom = cute.make_copy_atom(tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.q_dtype) + tiled_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStP0_bf16) thr_store = tiled_tmem_store.get_slice(sfw_idx) - tTMEM_STOREtP = thr_store.partition_D(tStP0) - # tScP: coordinate tensor for P store — use PV A-fragment layout + tTMEM_STOREtP = thr_store.partition_D(tStP0_bf16) + # Coordinate tensor for P store cP = cute.make_identity_tensor((self.pv_mma_tiler[0], p_cols_fp32)) - tOcP = pv_thr.partition_A(cP) # partition using PV thread slice + tOcP = pv_thr.partition_A(cP) tTMEM_STOREcP = thr_store.partition_S(cute.make_tensor(tOcP.iterator, tOrP0.layout)) row_max = -Float32.inf @@ -290,11 +290,10 @@ class FmhaKernel: acc_scale = Float32(0.0) row_sum *= acc_scale - rP_words = cute.make_rmem_tensor(tTMEM_STOREcP.shape, self.qk_acc_dtype) - rP_bf16 = cute.make_tensor(cute.recast_ptr(rP_words.iterator, dtype=self.q_dtype), tTMEM_LOADrS.layout) + # Store P to TMEM using PV A-fragment layout (BF16) + rP_bf16_reg = cute.make_rmem_tensor(tTMEM_STOREcP.shape, self.q_dtype) minus_row_max = Float32(0.0) - row_max_safe - - rP_bf16_frg = cute.logical_divide(rP_bf16, cute.make_layout(frg_tile)) + rP_bf16_frg = cute.logical_divide(rP_bf16_reg, cute.make_layout(frg_tile)) for j in range(frg_cnt): for k in range(cute.size(tTMEM_LOADrS_frg, mode=[0])): tTMEM_LOADrS_frg[k, j] = tTMEM_LOADrS_frg[k, j] * scale_log2 + minus_row_max @@ -303,7 +302,7 @@ class FmhaKernel: s_vec = tTMEM_LOADrS_frg[None, j].load() rP_bf16_frg[None, j].store(s_vec.to(self.q_dtype)) - cute.copy(tiled_tmem_store, rP_words, tTMEM_STOREtP) + cute.copy(tiled_tmem_store, rP_bf16_reg, tTMEM_STOREtP) cute.arch.fence_view_async_tmem_store() # Per-tile O rescale (hand-constructed atoms with logical_divide layout)