Add c_simple GMEM tensor (non-dynamic) for SMEM accumulator TMA store

This commit is contained in:
2026-05-27 05:33:30 +00:00
parent 2438826eee
commit b02e103ac0
2 changed files with 15 additions and 6 deletions

View File

@@ -126,7 +126,7 @@ class FmhaKernel:
cute.size_in_bytes(self.q_dtype, v_s)) * cta
@cute.jit
def __call__(self, q, k, v, c, stream, lse=None, swa_len=None, sink_bias=None, row_sums=None):
def __call__(self, q, k, v, c, stream, lse=None, swa_len=None, sink_bias=None, row_sums=None, c_simple=None):
self.q_dtype = q.element_type; self.o_dtype = c.element_type; self.c_dtype = self.o_dtype
self.a_major = LayoutEnum.from_tensor(q).mma_major_mode()
self.b_major = LayoutEnum.from_tensor(k).mma_major_mode()
@@ -171,11 +171,14 @@ class FmhaKernel:
# For single-head (n_h=1): grid=(1,1,1) — backward compatible
if const_expr(row_sums is None):
row_sums = cute.make_tensor(lse.iterator, lse.layout)
# c_simple: non-dynamic-layout GMEM tensor for direct TMA store (SMEM accumulator path)
if const_expr(c_simple is None):
c_simple = cute.make_tensor(c.iterator, cute.make_layout((1, 1, 1), stride=(1, 1, 1)))
self._kernel(qk_mma,pv_mma,tma_q,mQ,tma_k,mK,tma_v,mV,tma_c,mC,self.cluster_layout_vmnk,self.q_smem_s,self.k_smem_s,self.v_smem_s,self.p_tmem_s,self.p_smem_s,self.c_smem_s,self.epi_tile,lse,swa_len,sink_bias,row_sums).launch(grid=(1,1,self.batch_size),block=[self.threads_per_cta,1,1],stream=stream)
self._kernel(qk_mma,pv_mma,tma_q,mQ,tma_k,mK,tma_v,mV,tma_c,mC,self.cluster_layout_vmnk,self.q_smem_s,self.k_smem_s,self.v_smem_s,self.p_tmem_s,self.p_smem_s,self.c_smem_s,self.epi_tile,lse,swa_len,sink_bias,row_sums,c_simple).launch(grid=(1,1,self.batch_size),block=[self.threads_per_cta,1,1],stream=stream)
@cute.kernel
def _kernel(self, qk_mma, pv_mma, tma_q, mQ, tma_k, mK, tma_v, mV, tma_c, mC, cl_vmnk, q_smem_s, k_smem_s, v_smem_s, p_tmem_s, p_smem_s, c_smem_s, epi_tile, mLSE, swa_len, mSinkBias, mRowSums):
def _kernel(self, qk_mma, pv_mma, tma_q, mQ, tma_k, mK, tma_v, mV, tma_c, mC, cl_vmnk, q_smem_s, k_smem_s, v_smem_s, p_tmem_s, p_smem_s, c_smem_s, epi_tile, mLSE, swa_len, mSinkBias, mRowSums, mCSimple):
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
tidx,_,_ = cute.arch.thread_idx()
if warp_idx == self.tma_warp_id:
@@ -624,7 +627,7 @@ class FmhaKernel:
cute.arch.fence_proxy("async.shared", space="cta")
# Step 2: TMA store sC_flat -> GMEM
gO = cute.local_tile(mC, cute.slice_(self.pv_mma_tiler, (None, None, 0)), (None, None, None))
gO = cute.local_tile(mCSimple, cute.slice_(self.pv_mma_tiler, (None, None, 0)), (None, None, None))
# Group modes to match: sC_flat is 2D, gO needs to be grouped to 2D
tOsC, tOgO = cpasync.tma_partition(
tma_c, 0, cute.make_layout(1),

View File

@@ -50,8 +50,12 @@ def test_smem_acc(hd=64, s_k=256, use_smem_p=False, normalize=False):
mLSE = ct.from_dlpack(lse_tensor).mark_layout_dynamic(leading_dim=ct.get_leading_dim(lse_tensor))
mRS = ct.from_dlpack(row_sums_tensor).mark_layout_dynamic(leading_dim=ct.get_leading_dim(row_sums_tensor))
# Simple GMEM tensor (non-dynamic-layout) for SMEM accumulator TMA store
c_simple_tensor = c_tile.clone()
mCSimple = ct.from_dlpack(c_simple_tensor) # No mark_layout_dynamic!
print(f' hd={hd}, s_k={s_k} ({n_kv_tiles} KV tiles, pv_n_tile={pv_n_tile}, n_pv_tiles={n_pv_tiles}): Compiling...', flush=True)
compiled = cute.compile(kernel, mQ, mK, mV, mC, stream, lse=mLSE, row_sums=mRS)
compiled = cute.compile(kernel, mQ, mK, mV, mC, stream, lse=mLSE, row_sums=mRS, c_simple=mCSimple)
for nt in range(n_pv_tiles):
v_start = nt * pv_n_tile
@@ -69,7 +73,9 @@ def test_smem_acc(hd=64, s_k=256, use_smem_p=False, normalize=False):
mLSE = ct.from_dlpack(lse_tensor).mark_layout_dynamic(leading_dim=ct.get_leading_dim(lse_tensor))
mRS = ct.from_dlpack(row_sums_tensor).mark_layout_dynamic(leading_dim=ct.get_leading_dim(row_sums_tensor))
compiled(mQ, mK, mV, mC, stream, lse=mLSE, row_sums=mRS)
mCSimple = ct.from_dlpack(c_tile) # No mark_layout_dynamic!
compiled(mQ, mK, mV, mC, stream, lse=mLSE, row_sums=mRS, c_simple=mCSimple)
torch.cuda.synchronize()
c[:, v_start:v_end, :] = c_tile