D1: K sub-tiling - qk_mma_tiler K-dim = k_tile=256, SMEM fits at hd=512

This commit is contained in:
2026-05-24 04:41:12 +00:00
parent 4564a264db
commit 2e732ce3a7

View File

@@ -29,9 +29,13 @@ class FmhaKernel:
self.use_2cta_instrs = False; self.epilog_sync_bar_id = 1
self.cluster_shape_mn = (1, 1); self.cta_group = tcgen05.CtaGroup.ONE
self.epilogue_warp_id = (0,1,2,3); self.mma_warp_id = 4; self.tma_warp_id = 5
self.threads_per_cta = 192; self.num_c_stage = 2
self.threads_per_cta = 192
# K-dim sub-tiling: cap at 256 to keep sQ and sK within SMEM budget
self.k_tile = min(head_dim, 256)
self.n_k_sub_tiles = head_dim // self.k_tile
self.kv_stage = 1 if head_dim > 128 else 2 # Reduce SMEM at large hd
self.q_stage = 1; self.num_c_stage = 2
self.q_stage = 1
self.num_c_stage = 1 if head_dim > 256 else 2 # Reduce SMEM at hd=512
self.scale_softmax = scale_softmax if scale_softmax is not None else 1.0 / math.sqrt(self.head_dim)
self.scale_softmax_log2 = self.scale_softmax * math.log2(math.e)
@@ -39,7 +43,7 @@ class FmhaKernel:
qk_ik = cute.size(qk_mma.shape_mnk, mode=[2])
# QK GEMM K-dim = head_dim. Each MMA sub-tile covers qk_ik*4 elements.
# The tiler K must be head_dim so the QK loop iterates over all K sub-tiles.
self.qk_mma_tiler = (128, 128, self.head_dim)
self.qk_mma_tiler = (128, 128, self.k_tile)
pv_ik = cute.size(pv_mma.shape_mnk, mode=[2])
self.pv_mma_tiler = (128, self.pv_n_tile, pv_ik * (128 // pv_ik))
self.mma_tiler = self.qk_mma_tiler