Merge branch 'master' of ssh://sweetapi.com:2222/biondizzle/nvfp4-megamoe-kernel
This commit is contained in:
@@ -89,7 +89,7 @@ class FmhaKernel:
|
||||
cute.size_in_bytes(self.q_dtype, v_s)) * cta
|
||||
|
||||
@cute.jit
|
||||
def __call__(self, q, k, v, c, stream, lse=None):
|
||||
def __call__(self, q, k, v, c, stream, lse=None, gP=None, tma_p=None):
|
||||
self.q_dtype = q.element_type; self.o_dtype = c.element_type; self.c_dtype = self.o_dtype
|
||||
self.a_major = LayoutEnum.from_tensor(q).mma_major_mode()
|
||||
self.b_major = LayoutEnum.from_tensor(k).mma_major_mode()
|
||||
@@ -113,19 +113,33 @@ class FmhaKernel:
|
||||
tma_v,mV = cute.nvgpu.make_tiled_tma_atom_B(utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn,pv_mma.thr_id),v_fmha,v_s,self.pv_mma_tiler,pv_mma,self.cluster_layout_vmnk.shape)
|
||||
epi_s = cute.select(self.c_smem_s,mode=[0,1])
|
||||
tma_c,mC = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(),c,epi_s,self.epi_tile)
|
||||
|
||||
# SMEM-P: gP buffer and TMA for P (GMEM→SMEM via TMA)
|
||||
if self.use_smem_p and gP is not None:
|
||||
p_s = cute.slice_(self.p_smem_s,(None,None,None,0))
|
||||
tma_p,gP = cute.nvgpu.make_tiled_tma_atom_A(
|
||||
utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, pv_mma.thr_id),
|
||||
gP, p_s, self.qk_mma_tiler, pv_mma, self.cluster_layout_vmnk.shape
|
||||
)
|
||||
elif not self.use_smem_p:
|
||||
tma_p = tma_q # dummy, dead code
|
||||
else:
|
||||
raise ValueError("use_smem_p=True but no gP provided")
|
||||
# Always create a valid mLSE tensor for the kernel.
|
||||
# CuTeDSL doesn't support None parameters in @cute.kernel.
|
||||
# For normalize=True, mLSE is unused (dead-code-eliminated by compiler).
|
||||
if const_expr(lse is None):
|
||||
lse = cute.make_tensor(c.iterator, cute.make_layout((1,), stride=(0,)))
|
||||
self._kernel(qk_mma,pv_mma,tma_q,mQ,tma_k,mK,tma_v,mV,tma_c,mC,self.cluster_layout_vmnk,self.q_smem_s,self.k_smem_s,self.v_smem_s,self.p_tmem_s,self.p_smem_s,self.c_smem_s,self.epi_tile,lse).launch(grid=(1,1,1),block=[self.threads_per_cta,1,1],stream=stream)
|
||||
self._kernel(qk_mma,pv_mma,tma_q,mQ,tma_k,mK,tma_v,mV,tma_c,mC,tma_p,gP,self.cluster_layout_vmnk,self.q_smem_s,self.k_smem_s,self.v_smem_s,self.p_tmem_s,self.p_smem_s,self.c_smem_s,self.epi_tile,lse).launch(grid=(1,1,1),block=[self.threads_per_cta,1,1],stream=stream)
|
||||
|
||||
@cute.kernel
|
||||
def _kernel(self, qk_mma, pv_mma, tma_q, mQ, tma_k, mK, tma_v, mV, tma_c, mC, cl_vmnk, q_smem_s, k_smem_s, v_smem_s, p_tmem_s, p_smem_s, c_smem_s, epi_tile, mLSE):
|
||||
def _kernel(self, qk_mma, pv_mma, tma_q, mQ, tma_k, mK, tma_v, mV, tma_c, mC, tma_p, mGP, cl_vmnk, q_smem_s, k_smem_s, v_smem_s, p_tmem_s, p_smem_s, c_smem_s, epi_tile, mLSE):
|
||||
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
|
||||
tidx,_,_ = cute.arch.thread_idx()
|
||||
if warp_idx == self.tma_warp_id:
|
||||
cpasync.prefetch_descriptor(tma_q); cpasync.prefetch_descriptor(tma_k); cpasync.prefetch_descriptor(tma_v); cpasync.prefetch_descriptor(tma_c)
|
||||
if const_expr(self.use_smem_p):
|
||||
cpasync.prefetch_descriptor(tma_p)
|
||||
|
||||
@cute.struct
|
||||
class SS:
|
||||
@@ -226,6 +240,12 @@ class FmhaKernel:
|
||||
sh.commit()
|
||||
softmax_done_bar.arrive_and_wait()
|
||||
pv_mma.set(tcgen05.Field.ACCUMULATE, kt != 0)
|
||||
if const_expr(self.use_smem_p):
|
||||
# SMEM-P: TMA load gP → sP (MMA warp does this after barrier)
|
||||
tPgP, tPsP = cpasync.tma_partition(tma_p, 0, cute.nvgpu.OperandMajorMode.M, cute.group_modes(sP,0,3), cute.group_modes(mGP,0,3))
|
||||
cute.copy(tma_p, tPsP[(None,0,None,0)], tPgP[(None,0,None,0)], tma_bar_ptr=st.s_bar.data_ptr())
|
||||
cpasync.commit_group()
|
||||
cpasync.wait_group(0)
|
||||
if not self.use_smem_p:
|
||||
# TMEM-P: PV reads P from TMEM
|
||||
for kb in cutlass.range(cute.size(tOrP0, mode=[2]), unroll_full=True):
|
||||
@@ -368,55 +388,19 @@ class FmhaKernel:
|
||||
cute.copy(tiled_tmem_store, rP_words, tTMEM_STOREtP)
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
else:
|
||||
# SMEM-P: write P to sP using coordinate-indexed store.
|
||||
# Uses tTMEM_LOADcS identity tensor to get (m, k) coordinates.
|
||||
# The sP layout is PV A-operand SMEM: ((128,16),1,(4,2),1) for hd=64,
|
||||
# but changes at larger hd. We write using sP's actual subtile structure.
|
||||
# sP[(m_sub, k_sub), 0, (k_group1, k_group2)]
|
||||
# where m is decomposed as m_sub in [0, 128) with subtile 16,
|
||||
# and k is decomposed as k_sub in [0, 16), k_group1 in [0, 4), k_group2.
|
||||
#
|
||||
# The QK C-fragment has 128 columns (s_k=128). At hd=64, P uses 64 of them.
|
||||
# At hd=128, P uses all 128.
|
||||
# The sP subtile pattern depends on pv_mma's K-dim (pv_n_tile / 128 * qk_mma_tiler[2]).
|
||||
#
|
||||
# We iterate over the identity tensor to get (m, k) for each P value,
|
||||
# then compute sP indices from the actual sP shape.
|
||||
#
|
||||
# For the sP layout ((M_atom, K_atom), 1, (K_group1, K_group2), stage):
|
||||
# m_idx = m_coord (0..127)
|
||||
# k_idx = k_coord (0..min(s_k, pv_n_tile)-1)
|
||||
# sub_m = m_idx % M_atom, group_m = m_idx // M_atom
|
||||
# sub_k = k_idx % K_atom, group_k1 = (k_idx // K_atom) % K_group1_size, group_k2 = k_idx // (K_atom * K_group1_size)
|
||||
#
|
||||
# We read from sP's shape to determine the tiling.
|
||||
_sP_shape = cute.shape(_sP_nostage)
|
||||
_M_atom = _sP_shape[0][0] # e.g. 128
|
||||
_K_atom = _sP_shape[0][1] # e.g. 16
|
||||
_K_g1 = _sP_shape[2][0] # e.g. 4
|
||||
_K_g2 = _sP_shape[2][1] # e.g. 2
|
||||
for j0 in range(32):
|
||||
for j1 in range(4):
|
||||
coord = tTMEM_LOADcS[(j0, 0), j1, 0, 0]
|
||||
m_coord = coord[0]
|
||||
k_coord = coord[1]
|
||||
# Skip if k_coord is beyond P's columns
|
||||
# P has p_cols_fp32 * (qk_acc_dtype.width / q_dtype.width) BF16 elements per row
|
||||
# p_cols_fp32 = pv_n_tile * q_dtype.width / qk_acc_dtype.width
|
||||
# At hd=64: p_cols_bf16 = 64, so k ranges 0..63
|
||||
# At hd=128: p_cols_bf16 = 128, so k ranges 0..127
|
||||
# The identity tensor maps all 128 QK columns. We write all of them
|
||||
# (PV only reads the first pv_n_tile columns from SMEM).
|
||||
sub_m = m_coord % _K_atom # within-atom M index
|
||||
grp_m = m_coord // _K_atom # M group (should be 0 for 128-row tile)
|
||||
# Note: sP mode 0 is (M_atom, K_atom), so M is the first sub-mode of mode 0.
|
||||
# Mode 0 is (128, 16), i.e. (M, K_subtile).
|
||||
# So m_idx maps directly to the first dim of mode 0, and k_idx maps to the second dim of mode 0 plus mode 2.
|
||||
k_sub = k_coord % _K_atom
|
||||
k_g1 = (k_coord // _K_atom) % _K_g1
|
||||
k_g2 = k_coord // (_K_atom * _K_g1)
|
||||
_sP_nostage[(m_coord, k_sub), 0, (k_g1, k_g2)] = rP_bf16[(j0, 0), j1, 0, 0]
|
||||
cute.arch.fence_proxy("async.shared", space="cta")
|
||||
# SMEM-P: write P to gP (global memory), then TMA loads gP→sP.
|
||||
# rP_bf16 and tCgP are both partitioned by the QK MMA's C-fragment,
|
||||
# so they have the same thread→value mapping. A simple element-wise
|
||||
# copy from rP_bf16 to tCgP puts P values at the correct gP positions.
|
||||
gP_local = cute.local_tile(mGP, (128, self.s_k), (0, 0))
|
||||
tCgP = qk_thr.partition_C(gP_local)
|
||||
# Flatten both tensors for element-wise copy
|
||||
rP_flat = cute.make_tensor(rP_bf16.iterator, cute.coalesce(rP_bf16.layout))
|
||||
gP_flat = cute.make_tensor(tCgP.iterator, cute.coalesce(tCgP.layout))
|
||||
# Copy element-by-element (both should have 128 values per thread)
|
||||
for idx in cutlass.range(cute.size(rP_flat), vectorize=True):
|
||||
gP_flat[idx] = rP_flat[idx]
|
||||
cute.arch.fence_proxy("async.global", space="cta")
|
||||
if kt > 0:
|
||||
for i in range(n_corr_tiles):
|
||||
tTMrO_i_ = tTMrO[None, i]
|
||||
@@ -480,7 +464,7 @@ class FmhaKernel:
|
||||
if sfw_idx == 0:
|
||||
_ln2 = Float32(0.6931471805599453) # ln(2)
|
||||
lse_val = cute.math.log(row_sum, fastmath=True) + _row_max_safe * _ln2
|
||||
mLSE[0] = lse_val
|
||||
mLSE[0] = lse_val.to(self.q_dtype)
|
||||
|
||||
tmem.relinquish_alloc_permit()
|
||||
tmem.free(tmem_ptr)
|
||||
|
||||
Reference in New Issue
Block a user