D1: K sub-tile MMA path using pipeline barriers

This commit is contained in:
2026-05-24 04:57:08 +00:00
parent fd28718483
commit e93dabe43c

View File

@@ -213,23 +213,21 @@ class FmhaKernel:
# ===== TMA LOAD warp =====
if warp_idx == self.tma_warp_id:
if const_expr(self.n_k_sub_tiles > 1):
# K sub-tiling path (hd=512): serialized TMA loads with barrier sync
_ksub_bar = pipeline.NamedBarrier(barrier_id=6, num_threads=64) # TMA + MMA warps
_v_bar = pipeline.NamedBarrier(barrier_id=7, num_threads=64) # TMA + MMA warps
# K sub-tiling path (hd=512): load Q and K per k_sub using pipeline barriers
for kt in range(self.n_kv_tiles):
for k_sub in range(self.n_k_sub_tiles):
# Load Q[k_sub] → sQ (no TMA barrier, use cp_async wait)
cute.copy(tma_q, tAgQ[(None, Int32(k_sub))], tAsQ[(None, 0)])
# Load Q[k_sub] → sQ
qh = qp.acquire_and_advance()
cute.copy(tma_q, tAgQ[(None, Int32(k_sub))], tAsQ[(None, qh.index)], tma_bar_ptr=qh.barrier)
# Load K[k_sub] → sK
cute.copy(tma_k, tBgK[(None, Int32(k_sub))], tBsK[(None, 0)])
cute.arch.cp_async_bulk_commit_group()
cute.arch.cp_async_bulk_wait_group(0, read=True)
_ksub_bar.arrive_and_wait()
kvh = kvp.acquire_and_advance()
cute.copy(tma_k, tBgK[(None, Int32(k_sub))], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
qh = qp.wait_and_advance(); qh.release()
kvh = kvp.wait_and_advance(); pk = cutlass.Boolean(1)
# Load V[kt] → sV
cute.copy(tma_v, tVgV[(None, Int32(kt))], tVsV[(None, 0)])
cute.arch.cp_async_bulk_commit_group()
cute.arch.cp_async_bulk_wait_group(0, read=True)
_v_bar.arrive_and_wait()
# (V doesn't depend on k_sub, load once per kt)
# V is already loaded in the K/V pipeline
kvp.tail()
else:
# Original pipeline path (hd≤256)
qp.reset(); qh = qp.acquire_and_advance()
@@ -247,36 +245,30 @@ class FmhaKernel:
if warp_idx == self.mma_warp_id:
tmem.wait_for_alloc()
if const_expr(self.n_k_sub_tiles > 1):
# K sub-tiling path (hd=512): serialized with barrier sync
_ksub_bar = pipeline.NamedBarrier(barrier_id=6, num_threads=64)
_v_bar = pipeline.NamedBarrier(barrier_id=7, num_threads=64)
# K sub-tiling path (hd=512): pipeline sync for k_sub iterations
for kt in range(self.n_kv_tiles):
for k_sub in range(self.n_k_sub_tiles):
# Wait for TMA warp: sQ and sK ready
_ksub_bar.arrive_and_wait()
# QK GEMM: load Q from sQ, K from sK, accumulate into S
qh = qc.wait_and_advance(); qh.release()
kvh = kvc.wait_and_advance()
qk_mma.set(tcgen05.Field.ACCUMULATE, k_sub != 0)
for kb in cutlass.range(cute.size(tCrQ, mode=[2]), unroll_full=True):
cute.gemm(qk_mma, tStS0, tCrQ[(None,None,kb,0)], tCrK[(None,None,kb,0)], tStS0)
cute.gemm(qk_mma, tStS0, tCrQ[(None,None,kb,0)], tCrK[(None,None,kb,kvh.index)], tStS0)
qk_mma.set(tcgen05.Field.ACCUMULATE, True)
# After all k_sub: S has full QK for this kt
cute.arch.fence_view_async_tmem_store()
# Signal softmax
softmax_done_bar.arrive() # signal that S is ready
# Wait for V load
_v_bar.arrive_and_wait()
# PV GEMM
softmax_done_bar.arrive()
softmax_done_bar.arrive_and_wait()
pv_mma.set(tcgen05.Field.ACCUMULATE, kt != 0)
if not self.use_smem_p:
for kb in cutlass.range(cute.size(tOrP0, mode=[2]), unroll_full=True):
cute.gemm(pv_mma, tOtO0, tOrP0[(None,None,kb)], tCrV[(None,None,kb,0)], tOtO0)
cute.gemm(pv_mma, tOtO0, tOrP0[(None,None,kb)], tCrV[(None,None,kb,kvh.index)], tOtO0)
pv_mma.set(tcgen05.Field.ACCUMULATE, True)
else:
for kb in cutlass.range(cute.size(tCrP, mode=[2]), unroll_full=True):
cute.gemm(pv_mma, tOtO0, tCrP[(None,None,kb,0)], tCrV[(None,None,kb,0)], tOtO0)
cute.gemm(pv_mma, tOtO0, tCrP[(None,None,kb,0)], tCrV[(None,None,kb,kvh.index)], tOtO0)
pv_mma.set(tcgen05.Field.ACCUMULATE, True)
cute.arch.fence_view_async_tmem_store()
softmax_done_bar.arrive() # signal PV complete for this kt
kvh.release()
final_o_bar.arrive()
else:
# Original pipeline path (hd≤256)