D2: revert multi-CTA grid params (using per-head launch approach instead)

This commit is contained in:
2026-05-24 22:52:21 +00:00
parent a5271821a8
commit e0339a92fc

View File

@@ -16,7 +16,7 @@ import math
class FmhaKernel:
def __init__(self, head_dim=64, s_k=128, scale_softmax=None, use_smem_p=None, normalize=True, num_query_heads=1, batch_size=1):
def __init__(self, head_dim=64, s_k=128, scale_softmax=None, use_smem_p=None, normalize=True):
self.head_dim = head_dim
self.s_k = s_k
self.n_kv_tiles = s_k // 128
@@ -43,12 +43,6 @@ class FmhaKernel:
self.num_c_stage = 1 if head_dim > 256 else 2 # Reduce SMEM at hd=512
self.scale_softmax = scale_softmax if scale_softmax is not None else 1.0 / math.sqrt(self.head_dim)
self.scale_softmax_log2 = self.scale_softmax * math.log2(math.e)
self.num_query_heads = num_query_heads
self.batch_size = batch_size
# D2: Multi-CTA grid. Each CTA handles one (head, batch) pair.
# Grid: (1, num_query_heads * batch_size, 1)
# Total CTAs = num_query_heads * batch_size
self.num_ctas = num_query_heads * batch_size
def _setup(self, qk_mma, pv_mma):
qk_ik = cute.size(qk_mma.shape_mnk, mode=[2])
@@ -135,7 +129,7 @@ class FmhaKernel:
# CuTeDSL doesn't support None parameters in @cute.kernel.
if const_expr(lse is None):
lse = cute.make_tensor(c.iterator, cute.make_layout((1,), stride=(0,)))
self._kernel(qk_mma,pv_mma,tma_q,mQ,tma_k,mK,tma_v,mV,tma_c,mC,self.cluster_layout_vmnk,self.q_smem_s,self.k_smem_s,self.v_smem_s,self.p_tmem_s,self.p_smem_s,self.c_smem_s,self.epi_tile,lse).launch(grid=(1,self.num_ctas,1),block=[self.threads_per_cta,1,1],stream=stream)
self._kernel(qk_mma,pv_mma,tma_q,mQ,tma_k,mK,tma_v,mV,tma_c,mC,self.cluster_layout_vmnk,self.q_smem_s,self.k_smem_s,self.v_smem_s,self.p_tmem_s,self.p_smem_s,self.c_smem_s,self.epi_tile,lse).launch(grid=(1,1,1),block=[self.threads_per_cta,1,1],stream=stream)
@cute.kernel
def _kernel(self, qk_mma, pv_mma, tma_q, mQ, tma_k, mK, tma_v, mV, tma_c, mC, cl_vmnk, q_smem_s, k_smem_s, v_smem_s, p_tmem_s, p_smem_s, c_smem_s, epi_tile, mLSE):