Fix PV A-operand major mode: K for TMEM-P, a_major for SMEM-P
This commit is contained in:
@@ -116,8 +116,11 @@ class FmhaKernel:
|
||||
self.cta_group, (128, 128), tcgen05.OperandSource.SMEM,
|
||||
)
|
||||
pv_src = tcgen05.OperandSource.SMEM if self.use_smem_p else tcgen05.OperandSource.TMEM
|
||||
+# When PV reads P from TMEM, P has K-major layout (QK C-fragment alias).
|
||||
+# When PV reads P from SMEM, P has Q's major mode (loaded into SMEM).
|
||||
+pv_a_major = self.a_major if self.use_smem_p else cute.nvgpu.OperandMajorMode.K
|
||||
pv_mma = utils.sm100.make_trivial_tiled_mma(
|
||||
-self.q_dtype, self.q_dtype, self.a_major, self.v_major, self.qk_acc_dtype,
|
||||
+self.q_dtype, self.q_dtype, pv_a_major, self.v_major, self.qk_acc_dtype,
|
||||
self.cta_group, (128, self.pv_n_tile), pv_src,
|
||||
)
|
||||
self._setup(qk_mma, pv_mma)
|
||||
|
||||
Reference in New Issue
Block a user