diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index 1f7115de..b6777452 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -13,11 +13,10 @@ import cuda.bindings.driver as cuda import cutlass.torch as ct import math -HEAD_DIM = 64 - class FmhaKernel: - def __init__(self, s_k=128, scale_softmax=None): + def __init__(self, head_dim=64, s_k=128, scale_softmax=None): + self.head_dim = head_dim self.s_k = s_k self.n_kv_tiles = s_k // 128 self.acc_dtype = Float32; self.qk_acc_dtype = Float32 @@ -27,17 +26,17 @@ class FmhaKernel: self.epilogue_warp_id = (0,1,2,3); self.mma_warp_id = 4; self.tma_warp_id = 5 self.threads_per_cta = 192; self.num_c_stage = 2 self.kv_stage = 2; self.q_stage = 1; self.num_c_stage = 2 - self.scale_softmax = scale_softmax if scale_softmax is not None else 1.0 / math.sqrt(HEAD_DIM) + self.scale_softmax = scale_softmax if scale_softmax is not None else 1.0 / math.sqrt(self.head_dim) self.scale_softmax_log2 = self.scale_softmax * math.log2(math.e) def _setup(self, qk_mma, pv_mma): qk_ik = cute.size(qk_mma.shape_mnk, mode=[2]) self.qk_mma_tiler = (128, 128, qk_ik * 4) pv_ik = cute.size(pv_mma.shape_mnk, mode=[2]) - self.pv_mma_tiler = (128, HEAD_DIM, pv_ik * (128 // pv_ik)) + self.pv_mma_tiler = (128, self.head_dim, pv_ik * (128 // pv_ik)) self.mma_tiler = self.qk_mma_tiler self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,)) - self.cta_tile_shape_mnk = (self.qk_mma_tiler[0]//cute.size(qk_mma.thr_id.shape), HEAD_DIM, self.qk_mma_tiler[2]) + self.cta_tile_shape_mnk = (self.qk_mma_tiler[0]//cute.size(qk_mma.thr_id.shape), self.head_dim, self.qk_mma_tiler[2]) self.c_layout = LayoutEnum.ROW_MAJOR self.epi_tile = utils.sm100.compute_epilogue_tile_shape(self.cta_tile_shape_mnk, False, self.c_layout, self.o_dtype) self.num_ab_stage = 1; self.num_acc_stage = 1 @@ -77,14 +76,14 @@ class FmhaKernel: v_fmha = cute.make_tensor( v.iterator, cute.make_layout( - (HEAD_DIM, self.s_k, 1), - stride=(1, HEAD_DIM, HEAD_DIM * self.s_k), + (self.head_dim, self.s_k, 1), + stride=(1, self.head_dim, self.head_dim * self.s_k), ), ) self.v_major = LayoutEnum.from_tensor(v_fmha).mma_major_mode() self.c_layout = LayoutEnum.from_tensor(c) qk_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, self.a_major, self.b_major, self.qk_acc_dtype, self.cta_group, (128,128), tcgen05.OperandSource.SMEM) - pv_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, cute.nvgpu.OperandMajorMode.K, self.v_major, self.qk_acc_dtype, self.cta_group, (128,HEAD_DIM), tcgen05.OperandSource.TMEM) + pv_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, cute.nvgpu.OperandMajorMode.K, self.v_major, self.qk_acc_dtype, self.cta_group, (128,self.head_dim), tcgen05.OperandSource.TMEM) self._setup(qk_mma, pv_mma) q_s = cute.slice_(self.q_smem_s,(None,None,None,0)); k_s = cute.slice_(self.k_smem_s,(None,None,None,0)); v_s = cute.slice_(self.v_smem_s,(None,None,None,0)) tma_q,mQ = cute.nvgpu.make_tiled_tma_atom_A(utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn,qk_mma.thr_id),q,q_s,self.qk_mma_tiler,qk_mma,self.cluster_layout_vmnk.shape) @@ -255,7 +254,7 @@ class FmhaKernel: tTMEM_LOADtO = thr_tmem_load_o.partition_S(tOtO_i) tTMEM_LOADcO = thr_tmem_load_o.partition_D(tOcO_i) tTMEM_STOREtO = thr_tmem_store_o.partition_D(tOtO_i) - n_corr_tiles = HEAD_DIM // corr_tile_size + n_corr_tiles = self.head_dim // corr_tile_size for kt in range(self.n_kv_tiles): si_handle = s_cons.wait_and_advance()