diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index 57faef35..1ae02921 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -19,6 +19,9 @@ class FmhaKernel: self.head_dim = head_dim self.s_k = s_k self.n_kv_tiles = s_k // 128 + + self.pv_n_tile = min(head_dim, 256) # tcgen05 MMA max N=256 + self.n_pv_tiles = head_dim // self.pv_n_tile self.use_smem_p = use_smem_p if use_smem_p is not None else (head_dim > 64) self.acc_dtype = Float32; self.qk_acc_dtype = Float32 self.q_dtype = BFloat16; self.o_dtype = BFloat16; self.c_dtype = BFloat16 @@ -34,10 +37,10 @@ class FmhaKernel: qk_ik = cute.size(qk_mma.shape_mnk, mode=[2]) self.qk_mma_tiler = (128, 128, qk_ik * 4) pv_ik = cute.size(pv_mma.shape_mnk, mode=[2]) - self.pv_mma_tiler = (128, self.head_dim, pv_ik * (128 // pv_ik)) + self.pv_mma_tiler = (128, self.pv_n_tile, pv_ik * (128 // pv_ik)) self.mma_tiler = self.qk_mma_tiler self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,)) - self.cta_tile_shape_mnk = (self.qk_mma_tiler[0]//cute.size(qk_mma.thr_id.shape), self.head_dim, self.qk_mma_tiler[2]) + self.cta_tile_shape_mnk = (self.qk_mma_tiler[0]//cute.size(qk_mma.thr_id.shape), self.pv_n_tile, self.qk_mma_tiler[2]) self.c_layout = LayoutEnum.ROW_MAJOR self.epi_tile = utils.sm100.compute_epilogue_tile_shape(self.cta_tile_shape_mnk, False, self.c_layout, self.o_dtype) self.num_ab_stage = 1; self.num_acc_stage = 1 @@ -98,7 +101,7 @@ class FmhaKernel: qk_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, self.a_major, self.b_major, self.qk_acc_dtype, self.cta_group, (128,128), tcgen05.OperandSource.SMEM) pv_a_major = self.a_major if self.use_smem_p else cute.nvgpu.OperandMajorMode.K pv_source = tcgen05.OperandSource.SMEM if self.use_smem_p else tcgen05.OperandSource.TMEM - pv_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, pv_a_major, self.v_major, self.qk_acc_dtype, self.cta_group, (128,self.head_dim), pv_source) + pv_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, pv_a_major, self.v_major, self.qk_acc_dtype, self.cta_group, (128,self.pv_n_tile), pv_source) self._setup(qk_mma, pv_mma) q_s = cute.slice_(self.q_smem_s,(None,None,None,0)); k_s = cute.slice_(self.k_smem_s,(None,None,None,0)); v_s = cute.slice_(self.v_smem_s,(None,None,None,0)) tma_q,mQ = cute.nvgpu.make_tiled_tma_atom_A(utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn,qk_mma.thr_id),q,q_s,self.qk_mma_tiler,qk_mma,self.cluster_layout_vmnk.shape) @@ -295,7 +298,7 @@ class FmhaKernel: tTMEM_LOADtO = thr_tmem_load_o.partition_S(tOtO_i) tTMEM_LOADcO = thr_tmem_load_o.partition_D(tOcO_i) tTMEM_STOREtO = thr_tmem_store_o.partition_D(tOtO_i) - n_corr_tiles = self.head_dim // corr_tile_size + n_corr_tiles = self.pv_n_tile // corr_tile_size for kt in range(self.n_kv_tiles): si_handle = s_cons.wait_and_advance()