D1: Fix pipeline API for K sub-tile path (producer_acquire/commit)

This commit is contained in:
2026-05-24 04:59:41 +00:00
parent 9afef9ed7d
commit 0db580f18a

View File

@@ -217,17 +217,19 @@ class FmhaKernel:
for kt in range(self.n_kv_tiles):
for k_sub in range(self.n_k_sub_tiles):
# Load Q[k_sub] → sQ
qh = qp.acquire_and_advance()
cute.copy(tma_q, tAgQ[(None, Int32(k_sub))], tAsQ[(None, qh.index)], tma_bar_ptr=qh.barrier)
qp.producer_acquire(qp_stage)
cute.copy(tma_q, tAgQ[(None, Int32(k_sub))], tAsQ[(None, qp_stage.index)], tma_bar_ptr=qp_stage.barrier)
qp.producer_commit(qp_stage); qp_stage.advance()
# Load K[k_sub] → sK
kvh = kvp.acquire_and_advance()
cute.copy(tma_k, tBgK[(None, Int32(k_sub))], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
qh = qp.wait_and_advance(); qh.release()
kvh = kvp.wait_and_advance(); pk = cutlass.Boolean(1)
kvp.producer_acquire(kvp_stage)
cute.copy(tma_k, tBgK[(None, Int32(k_sub))], tBsK[(None, kvp_stage.index)], tma_bar_ptr=kvp_stage.barrier)
kvp.producer_commit(kvp_stage); kvp_stage.advance()
# Load V[kt] → sV
# (V doesn't depend on k_sub, load once per kt)
# V is already loaded in the K/V pipeline
kvp.tail()
kvp.producer_acquire(kvp_stage)
cute.copy(tma_v, tVgV[(None, Int32(kt))], tVsV[(None, kvp_stage.index)], tma_bar_ptr=kvp_stage.barrier)
kvp.producer_commit(kvp_stage); kvp_stage.advance()
qp.producer_tail(qp_stage)
kvp.producer_tail(kvp_stage)
else:
# Original pipeline path (hd≤256)
qp.reset(); qh = qp.acquire_and_advance()
@@ -246,6 +248,8 @@ class FmhaKernel:
tmem.wait_for_alloc()
if const_expr(self.n_k_sub_tiles > 1):
# K sub-tiling path (hd=512): pipeline sync for k_sub iterations
qp_stage = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.q_stage)
kvp_stage = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.kv_stage)
for kt in range(self.n_kv_tiles):
for k_sub in range(self.n_k_sub_tiles):
qh = qc.wait_and_advance(); qh.release()