D1.0: Replace HEAD_DIM=64 with self.head_dim constructor parameter

This commit is contained in:
2026-05-23 05:55:03 +00:00
parent 249a581d8a
commit bd0b56dddd

View File

@@ -13,11 +13,10 @@ import cuda.bindings.driver as cuda
import cutlass.torch as ct
import math
HEAD_DIM = 64
class FmhaKernel:
def __init__(self, s_k=128, scale_softmax=None):
def __init__(self, head_dim=64, s_k=128, scale_softmax=None):
self.head_dim = head_dim
self.s_k = s_k
self.n_kv_tiles = s_k // 128
self.acc_dtype = Float32; self.qk_acc_dtype = Float32
@@ -27,17 +26,17 @@ class FmhaKernel:
self.epilogue_warp_id = (0,1,2,3); self.mma_warp_id = 4; self.tma_warp_id = 5
self.threads_per_cta = 192; self.num_c_stage = 2
self.kv_stage = 2; self.q_stage = 1; self.num_c_stage = 2
self.scale_softmax = scale_softmax if scale_softmax is not None else 1.0 / math.sqrt(HEAD_DIM)
self.scale_softmax = scale_softmax if scale_softmax is not None else 1.0 / math.sqrt(self.head_dim)
self.scale_softmax_log2 = self.scale_softmax * math.log2(math.e)
def _setup(self, qk_mma, pv_mma):
qk_ik = cute.size(qk_mma.shape_mnk, mode=[2])
self.qk_mma_tiler = (128, 128, qk_ik * 4)
pv_ik = cute.size(pv_mma.shape_mnk, mode=[2])
self.pv_mma_tiler = (128, HEAD_DIM, pv_ik * (128 // pv_ik))
self.pv_mma_tiler = (128, self.head_dim, pv_ik * (128 // pv_ik))
self.mma_tiler = self.qk_mma_tiler
self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,))
self.cta_tile_shape_mnk = (self.qk_mma_tiler[0]//cute.size(qk_mma.thr_id.shape), HEAD_DIM, self.qk_mma_tiler[2])
self.cta_tile_shape_mnk = (self.qk_mma_tiler[0]//cute.size(qk_mma.thr_id.shape), self.head_dim, self.qk_mma_tiler[2])
self.c_layout = LayoutEnum.ROW_MAJOR
self.epi_tile = utils.sm100.compute_epilogue_tile_shape(self.cta_tile_shape_mnk, False, self.c_layout, self.o_dtype)
self.num_ab_stage = 1; self.num_acc_stage = 1
@@ -77,14 +76,14 @@ class FmhaKernel:
v_fmha = cute.make_tensor(
v.iterator,
cute.make_layout(
(HEAD_DIM, self.s_k, 1),
stride=(1, HEAD_DIM, HEAD_DIM * self.s_k),
(self.head_dim, self.s_k, 1),
stride=(1, self.head_dim, self.head_dim * self.s_k),
),
)
self.v_major = LayoutEnum.from_tensor(v_fmha).mma_major_mode()
self.c_layout = LayoutEnum.from_tensor(c)
qk_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, self.a_major, self.b_major, self.qk_acc_dtype, self.cta_group, (128,128), tcgen05.OperandSource.SMEM)
pv_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, cute.nvgpu.OperandMajorMode.K, self.v_major, self.qk_acc_dtype, self.cta_group, (128,HEAD_DIM), tcgen05.OperandSource.TMEM)
pv_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, cute.nvgpu.OperandMajorMode.K, self.v_major, self.qk_acc_dtype, self.cta_group, (128,self.head_dim), tcgen05.OperandSource.TMEM)
self._setup(qk_mma, pv_mma)
q_s = cute.slice_(self.q_smem_s,(None,None,None,0)); k_s = cute.slice_(self.k_smem_s,(None,None,None,0)); v_s = cute.slice_(self.v_smem_s,(None,None,None,0))
tma_q,mQ = cute.nvgpu.make_tiled_tma_atom_A(utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn,qk_mma.thr_id),q,q_s,self.qk_mma_tiler,qk_mma,self.cluster_layout_vmnk.shape)
@@ -255,7 +254,7 @@ class FmhaKernel:
tTMEM_LOADtO = thr_tmem_load_o.partition_S(tOtO_i)
tTMEM_LOADcO = thr_tmem_load_o.partition_D(tOcO_i)
tTMEM_STOREtO = thr_tmem_store_o.partition_D(tOtO_i)
n_corr_tiles = HEAD_DIM // corr_tile_size
n_corr_tiles = self.head_dim // corr_tile_size
for kt in range(self.n_kv_tiles):
si_handle = s_cons.wait_and_advance()