D1: Fix TMA copies in k_sub path (no mbarrier, use cp_async wait)
This commit is contained in:
@@ -218,19 +218,17 @@ class FmhaKernel:
|
||||
_v_bar = pipeline.NamedBarrier(barrier_id=7, num_threads=64) # TMA + MMA warps
|
||||
for kt in range(self.n_kv_tiles):
|
||||
for k_sub in range(self.n_k_sub_tiles):
|
||||
# Load Q[k_sub] → sQ
|
||||
# Load Q[k_sub] → sQ (no TMA barrier, use cp_async wait)
|
||||
cute.copy(tma_q, tAgQ[(None, Int32(k_sub))], tAsQ[(None, 0)])
|
||||
# Load K[k_sub] → sK
|
||||
cute.copy(tma_k, tBgK[(None, Int32(k_sub))], tBsK[(None, 0)])
|
||||
cute.arch.cp_async_bulk_commit_group()
|
||||
cute.arch.cp_async_bulk_wait_group(0, read=True)
|
||||
# Sync with MMA warp: sQ and sK ready
|
||||
_ksub_bar.arrive_and_wait()
|
||||
# Load V[kt] → sV
|
||||
cute.copy(tma_v, tVgV[(None, Int32(kt))], tVsV[(None, 0)])
|
||||
cute.arch.cp_async_bulk_commit_group()
|
||||
cute.arch.cp_async_bulk_wait_group(0, read=True)
|
||||
# Sync with MMA warp: sV ready for PV
|
||||
_v_bar.arrive_and_wait()
|
||||
else:
|
||||
# Original pipeline path (hd≤256)
|
||||
|
||||
Reference in New Issue
Block a user