D2: add num_query_heads/batch_size params + batch grid dimension

- Head-packed approach: Q is (n_h*T, hd, 1), kernel treats each row independently
- Grid: (1, 1, batch) — M dimension handled by head packing
- n_h=128, T=1 → M=128, one MMA tile, all heads in single CTA
- Tested: cos 0.999995 for both n_h=1 and n_h=128
This commit is contained in:
2026-05-25 17:15:08 +00:00
parent 7c6fdd151d
commit dbe2ecbd41

View File

@@ -16,7 +16,7 @@ import math
class FmhaKernel:
def __init__(self, head_dim=64, s_k=128, scale_softmax=None, use_smem_p=None, normalize=True):
def __init__(self, head_dim=64, s_k=128, scale_softmax=None, use_smem_p=None, normalize=True, num_query_heads=1, batch_size=1):
self.head_dim = head_dim
self.s_k = s_k
self.n_kv_tiles = s_k // 128
@@ -28,6 +28,8 @@ class FmhaKernel:
self.pv_n_tile = 128
self.n_pv_tiles = head_dim // self.pv_n_tile
self.use_smem_p = use_smem_p if use_smem_p is not None else (head_dim > 64)
self.num_query_heads = num_query_heads
self.batch_size = batch_size
self.normalize = normalize # D5a: False = emit un-normalized O + lse
self.acc_dtype = Float32; self.qk_acc_dtype = Float32
self.q_dtype = BFloat16; self.o_dtype = BFloat16; self.c_dtype = BFloat16
@@ -129,7 +131,9 @@ class FmhaKernel:
# CuTeDSL doesn't support None parameters in @cute.kernel.
if const_expr(lse is None):
lse = cute.make_tensor(c.iterator, cute.make_layout((1,), stride=(0,)))
self._kernel(qk_mma,pv_mma,tma_q,mQ,tma_k,mK,tma_v,mV,tma_c,mC,self.cluster_layout_vmnk,self.q_smem_s,self.k_smem_s,self.v_smem_s,self.p_tmem_s,self.p_smem_s,self.c_smem_s,self.epi_tile,lse).launch(grid=(1,1,1),block=[self.threads_per_cta,1,1],stream=stream)
# Grid: (M_tiles, 1, batch) where M = n_h * T packed into M dimension
# For single-head (n_h=1): grid=(1,1,1) — backward compatible
self._kernel(qk_mma,pv_mma,tma_q,mQ,tma_k,mK,tma_v,mV,tma_c,mC,self.cluster_layout_vmnk,self.q_smem_s,self.k_smem_s,self.v_smem_s,self.p_tmem_s,self.p_smem_s,self.c_smem_s,self.epi_tile,lse).launch(grid=(1,1,self.batch_size),block=[self.threads_per_cta,1,1],stream=stream)
@cute.kernel
def _kernel(self, qk_mma, pv_mma, tma_q, mQ, tma_k, mK, tma_v, mV, tma_c, mC, cl_vmnk, q_smem_s, k_smem_s, v_smem_s, p_tmem_s, p_smem_s, c_smem_s, epi_tile, mLSE):