From 2e732ce3a72473ba6ddbd9e8e8a9d60e224b4258 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sun, 24 May 2026 04:41:12 +0000 Subject: [PATCH] D1: K sub-tiling - qk_mma_tiler K-dim = k_tile=256, SMEM fits at hd=512 --- dsv4/kernels/attention/fmha.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index fc1442a5..2e91e554 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -29,9 +29,13 @@ class FmhaKernel: self.use_2cta_instrs = False; self.epilog_sync_bar_id = 1 self.cluster_shape_mn = (1, 1); self.cta_group = tcgen05.CtaGroup.ONE self.epilogue_warp_id = (0,1,2,3); self.mma_warp_id = 4; self.tma_warp_id = 5 - self.threads_per_cta = 192; self.num_c_stage = 2 + self.threads_per_cta = 192 + # K-dim sub-tiling: cap at 256 to keep sQ and sK within SMEM budget + self.k_tile = min(head_dim, 256) + self.n_k_sub_tiles = head_dim // self.k_tile self.kv_stage = 1 if head_dim > 128 else 2 # Reduce SMEM at large hd - self.q_stage = 1; self.num_c_stage = 2 + self.q_stage = 1 + self.num_c_stage = 1 if head_dim > 256 else 2 # Reduce SMEM at hd=512 self.scale_softmax = scale_softmax if scale_softmax is not None else 1.0 / math.sqrt(self.head_dim) self.scale_softmax_log2 = self.scale_softmax * math.log2(math.e) @@ -39,7 +43,7 @@ class FmhaKernel: qk_ik = cute.size(qk_mma.shape_mnk, mode=[2]) # QK GEMM K-dim = head_dim. Each MMA sub-tile covers qk_ik*4 elements. # The tiler K must be head_dim so the QK loop iterates over all K sub-tiles. - self.qk_mma_tiler = (128, 128, self.head_dim) + self.qk_mma_tiler = (128, 128, self.k_tile) pv_ik = cute.size(pv_mma.shape_mnk, mode=[2]) self.pv_mma_tiler = (128, self.pv_n_tile, pv_ik * (128 // pv_ik)) self.mma_tiler = self.qk_mma_tiler