fix: dummy tma_p reuse tma_q for non-SMEM-P

This commit is contained in:
2026-05-24 03:33:23 +00:00
parent 00c9991396
commit 9ea7f511b0

View File

@@ -113,19 +113,36 @@ class FmhaKernel:
tma_v,mV = cute.nvgpu.make_tiled_tma_atom_B(utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn,pv_mma.thr_id),v_fmha,v_s,self.pv_mma_tiler,pv_mma,self.cluster_layout_vmnk.shape)
epi_s = cute.select(self.c_smem_s,mode=[0,1])
tma_c,mC = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(),c,epi_s,self.epi_tile)
# SMEM-P: gP buffer and TMA for P (GMEM→SMEM via TMA)
if self.use_smem_p:
p_s = cute.slice_(self.p_smem_s,(None,None,None,0))
gP = torch.zeros(128, self.s_k, dtype=torch.bfloat16, device='cuda')
mgP = ct.from_dlpack(gP).mark_layout_dynamic(leading_dim=ct.get_leading_dim(gP))
tma_p,mgP = cute.nvgpu.make_tiled_tma_atom_A(
utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, pv_mma.thr_id),
mgP, p_s, self.qk_mma_tiler, pv_mma, self.cluster_layout_vmnk.shape
)
else:
# Dummy gP — tma_p not needed (dead code path)
gP = torch.zeros(1, 1, dtype=torch.bfloat16, device='cuda')
mgP = ct.from_dlpack(gP).mark_layout_dynamic(leading_dim=ct.get_leading_dim(gP))
tma_p = tma_q # reuse Q's TMA descriptor as dummy (dead code)
# Always create a valid mLSE tensor for the kernel.
# CuTeDSL doesn't support None parameters in @cute.kernel.
# For normalize=True, mLSE is unused (dead-code-eliminated by compiler).
if const_expr(lse is None):
lse = cute.make_tensor(c.iterator, cute.make_layout((1,), stride=(0,)))
self._kernel(qk_mma,pv_mma,tma_q,mQ,tma_k,mK,tma_v,mV,tma_c,mC,self.cluster_layout_vmnk,self.q_smem_s,self.k_smem_s,self.v_smem_s,self.p_tmem_s,self.p_smem_s,self.c_smem_s,self.epi_tile,lse).launch(grid=(1,1,1),block=[self.threads_per_cta,1,1],stream=stream)
self._kernel(qk_mma,pv_mma,tma_q,mQ,tma_k,mK,tma_v,mV,tma_c,mC,tma_p,mgP,self.cluster_layout_vmnk,self.q_smem_s,self.k_smem_s,self.v_smem_s,self.p_tmem_s,self.p_smem_s,self.c_smem_s,self.epi_tile,lse).launch(grid=(1,1,1),block=[self.threads_per_cta,1,1],stream=stream)
@cute.kernel
def _kernel(self, qk_mma, pv_mma, tma_q, mQ, tma_k, mK, tma_v, mV, tma_c, mC, cl_vmnk, q_smem_s, k_smem_s, v_smem_s, p_tmem_s, p_smem_s, c_smem_s, epi_tile, mLSE):
def _kernel(self, qk_mma, pv_mma, tma_q, mQ, tma_k, mK, tma_v, mV, tma_c, mC, tma_p, mgP, cl_vmnk, q_smem_s, k_smem_s, v_smem_s, p_tmem_s, p_smem_s, c_smem_s, epi_tile, mLSE):
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
tidx,_,_ = cute.arch.thread_idx()
if warp_idx == self.tma_warp_id:
cpasync.prefetch_descriptor(tma_q); cpasync.prefetch_descriptor(tma_k); cpasync.prefetch_descriptor(tma_v); cpasync.prefetch_descriptor(tma_c)
if const_expr(self.use_smem_p):
cpasync.prefetch_descriptor(tma_p)
@cute.struct
class SS:
@@ -226,6 +243,12 @@ class FmhaKernel:
sh.commit()
softmax_done_bar.arrive_and_wait()
pv_mma.set(tcgen05.Field.ACCUMULATE, kt != 0)
if const_expr(self.use_smem_p):
# SMEM-P: TMA load gP → sP (MMA warp does this after barrier)
tPgP, tPsP = cpasync.tma_partition(tma_p, 0, cute.nvgpu.OperandMajorMode.M, cute.group_modes(sP,0,3), cute.group_modes(mgP,0,3))
cute.copy(tma_p, tPsP[(None,0,None,0)], tPgP[(None,0,None,0)], tma_bar_ptr=st.s_bar.data_ptr())
cpasync.commit_group()
cpasync.wait_group(0)
if not self.use_smem_p:
# TMEM-P: PV reads P from TMEM
for kb in cutlass.range(cute.size(tOrP0, mode=[2]), unroll_full=True):
@@ -366,20 +389,19 @@ class FmhaKernel:
cute.copy(tiled_tmem_store, rP_words, tTMEM_STOREtP)
cute.arch.fence_view_async_tmem_store()
else:
# SMEM-P: write P to sP using make_tiled_copy_C + retile.
# The retile() call transforms rP_bf16 from QK C-fragment layout
# to the SMEM copy's source layout, matching partition_D(sP).
_smem_p_store_atom = cute.make_copy_atom(
cute.nvgpu.CopyUniversalOp(),
self.q_dtype,
num_bits_per_copy=16,
)
_tiled_smem_p = cute.make_tiled_copy_C(_smem_p_store_atom, qk_mma)
_thr_smem_p = _tiled_smem_p.get_slice(sfw_idx)
_tRS_sP = _thr_smem_p.partition_D(_sP_nostage)
_tRS_rP = _tiled_smem_p.retile(rP_bf16)
cute.copy(_tiled_smem_p, _tRS_rP, _tRS_sP)
cute.arch.fence_proxy("async.shared", space="cta")
# SMEM-P: write P to gP (global memory), then TMA loads gP→sP.
# rP_bf16 and tCgP are both partitioned by the QK MMA's C-fragment,
# so they have the same thread→value mapping. A simple element-wise
# copy from rP_bf16 to tCgP puts P values at the correct gP positions.
gP_local = cute.local_tile(mgP, (128, self.s_k), (0, 0))
tCgP = qk_thr.partition_C(gP_local)
# Flatten both tensors for element-wise copy
rP_flat = cute.make_tensor(rP_bf16.iterator, cute.coalesce(rP_bf16.layout))
gP_flat = cute.make_tensor(tCgP.iterator, cute.coalesce(tCgP.layout))
# Copy element-by-element (both should have 128 values per thread)
for idx in cutlass.range(cute.size(rP_flat), vectorize=True):
gP_flat[idx] = rP_flat[idx]
cute.arch.fence_proxy("async.global", space="cta")
if kt > 0:
for i in range(n_corr_tiles):
tTMrO_i_ = tTMrO[None, i]
@@ -443,7 +465,7 @@ class FmhaKernel:
if sfw_idx == 0:
_ln2 = Float32(0.6931471805599453) # ln(2)
lse_val = cute.math.log(row_sum, fastmath=True) + _row_max_safe * _ln2
mLSE[0] = lse_val
mLSE[0] = lse_val.to(self.q_dtype)
tmem.relinquish_alloc_permit()
tmem.free(tmem_ptr)