D1: Fix TMA copies in k_sub path (no mbarrier; wait via cp_async_bulk_commit_group / cp_async_bulk_wait_group)

This commit is contained in:
2026-05-24 04:53:46 +00:00
parent e637d3ae73
commit 98e974403c

View File

@@ -218,19 +218,17 @@ class FmhaKernel:
_v_bar = pipeline.NamedBarrier(barrier_id=7, num_threads=64) # TMA + MMA warps
for kt in range(self.n_kv_tiles):
for k_sub in range(self.n_k_sub_tiles):
# Load Q[k_sub] → sQ
# Load Q[k_sub] → sQ (no TMA barrier, use cp_async wait)
cute.copy(tma_q, tAgQ[(None, Int32(k_sub))], tAsQ[(None, 0)])
# Load K[k_sub] → sK
cute.copy(tma_k, tBgK[(None, Int32(k_sub))], tBsK[(None, 0)])
cute.arch.cp_async_bulk_commit_group()
cute.arch.cp_async_bulk_wait_group(0, read=True)
# Sync with MMA warp: sQ and sK ready
_ksub_bar.arrive_and_wait()
# Load V[kt] → sV
cute.copy(tma_v, tVgV[(None, Int32(kt))], tVsV[(None, 0)])
cute.arch.cp_async_bulk_commit_group()
cute.arch.cp_async_bulk_wait_group(0, read=True)
# Sync with MMA warp: sV ready for PV
_v_bar.arrive_and_wait()
else:
# Original pipeline path (hd≤256)