Stage D1: Multi-PV-tile support for hd>256 (tcgen05 MMA max N=256)

This commit is contained in:
2026-05-23 09:04:01 +00:00
parent f556060ddf
commit 2c36cd0d32

View File

@@ -15,11 +15,10 @@ import math
class FmhaKernel:
def __init__(self, head_dim=64, s_k=128, scale_softmax=None, use_smem_p=None):
def __init__(self, head_dim=64, s_k=128, scale_softmax=None, use_smem_p=None):
self.head_dim = head_dim
self.s_k = s_k
self.n_kv_tiles = s_k // 128
self.pv_n_tile = min(head_dim, 256) # tcgen05 MMA max N=256
self.n_pv_tiles = head_dim // self.pv_n_tile
self.use_smem_p = use_smem_p if use_smem_p is not None else (head_dim > 64)
@@ -32,7 +31,6 @@ class FmhaKernel:
self.kv_stage = 2; self.q_stage = 1; self.num_c_stage = 2
self.scale_softmax = scale_softmax if scale_softmax is not None else 1.0 / math.sqrt(self.head_dim)
self.scale_softmax_log2 = self.scale_softmax * math.log2(math.e)
def _setup(self, qk_mma, pv_mma):
qk_ik = cute.size(qk_mma.shape_mnk, mode=[2])
self.qk_mma_tiler = (128, 128, qk_ik * 4)