D1: K sub-tiling - qk_mma_tiler K-dim = k_tile=256, SMEM fits at hd=512
This commit is contained in:
@@ -29,9 +29,13 @@ class FmhaKernel:
|
||||
self.use_2cta_instrs = False; self.epilog_sync_bar_id = 1
|
||||
self.cluster_shape_mn = (1, 1); self.cta_group = tcgen05.CtaGroup.ONE
|
||||
self.epilogue_warp_id = (0,1,2,3); self.mma_warp_id = 4; self.tma_warp_id = 5
|
||||
self.threads_per_cta = 192; self.num_c_stage = 2
|
||||
self.threads_per_cta = 192
|
||||
# K-dim sub-tiling: cap at 256 to keep sQ and sK within SMEM budget
|
||||
self.k_tile = min(head_dim, 256)
|
||||
self.n_k_sub_tiles = head_dim // self.k_tile
|
||||
self.kv_stage = 1 if head_dim > 128 else 2 # Reduce SMEM at large hd
|
||||
self.q_stage = 1; self.num_c_stage = 2
|
||||
self.q_stage = 1
|
||||
self.num_c_stage = 1 if head_dim > 256 else 2 # Reduce SMEM at hd=512
|
||||
self.scale_softmax = scale_softmax if scale_softmax is not None else 1.0 / math.sqrt(self.head_dim)
|
||||
self.scale_softmax_log2 = self.scale_softmax * math.log2(math.e)
|
||||
|
||||
@@ -39,7 +43,7 @@ class FmhaKernel:
|
||||
qk_ik = cute.size(qk_mma.shape_mnk, mode=[2])
|
||||
# QK GEMM K-dim = head_dim. Each MMA sub-tile covers qk_ik*4 elements.
|
||||
# The tiler K must be head_dim so the QK loop iterates over all K sub-tiles.
|
||||
self.qk_mma_tiler = (128, 128, self.head_dim)
|
||||
self.qk_mma_tiler = (128, 128, self.k_tile)
|
||||
pv_ik = cute.size(pv_mma.shape_mnk, mode=[2])
|
||||
self.pv_mma_tiler = (128, self.pv_n_tile, pv_ik * (128 // pv_ik))
|
||||
self.mma_tiler = self.qk_mma_tiler
|
||||
|
||||
Reference in New Issue
Block a user