diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index a9008213..47bfee28 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -116,8 +116,11 @@ class FmhaKernel: self.cta_group, (128, 128), tcgen05.OperandSource.SMEM, ) pv_src = tcgen05.OperandSource.SMEM if self.use_smem_p else tcgen05.OperandSource.TMEM + # When PV reads P from TMEM, P has K-major layout (QK C-fragment alias). + # When PV reads P from SMEM, P has Q's major mode (loaded into SMEM). + pv_a_major = self.a_major if self.use_smem_p else cute.nvgpu.OperandMajorMode.K pv_mma = utils.sm100.make_trivial_tiled_mma( - self.q_dtype, self.q_dtype, self.a_major, self.v_major, self.qk_acc_dtype, + self.q_dtype, self.q_dtype, pv_a_major, self.v_major, self.qk_acc_dtype, self.cta_group, (128, self.pv_n_tile), pv_src, ) self._setup(qk_mma, pv_mma)