From 0db580f18a436eb0bf70ade3433d429b3b325cf7 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sun, 24 May 2026 04:59:41 +0000 Subject: [PATCH] D1: Fix pipeline API for K sub-tile path (producer_acquire/commit) --- dsv4/kernels/attention/fmha.py | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index 07a80476..3ac2b220 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -217,17 +217,19 @@ class FmhaKernel: for kt in range(self.n_kv_tiles): for k_sub in range(self.n_k_sub_tiles): # Load Q[k_sub] → sQ - qh = qp.acquire_and_advance() - cute.copy(tma_q, tAgQ[(None, Int32(k_sub))], tAsQ[(None, qh.index)], tma_bar_ptr=qh.barrier) + qp.producer_acquire(qp_stage) + cute.copy(tma_q, tAgQ[(None, Int32(k_sub))], tAsQ[(None, qp_stage.index)], tma_bar_ptr=qp_stage.barrier) + qp.producer_commit(qp_stage); qp_stage.advance() # Load K[k_sub] → sK - kvh = kvp.acquire_and_advance() - cute.copy(tma_k, tBgK[(None, Int32(k_sub))], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier) - qh = qp.wait_and_advance(); qh.release() - kvh = kvp.wait_and_advance(); pk = cutlass.Boolean(1) + kvp.producer_acquire(kvp_stage) + cute.copy(tma_k, tBgK[(None, Int32(k_sub))], tBsK[(None, kvp_stage.index)], tma_bar_ptr=kvp_stage.barrier) + kvp.producer_commit(kvp_stage); kvp_stage.advance() # Load V[kt] → sV - # (V doesn't depend on k_sub, load once per kt) - # V is already loaded in the K/V pipeline - kvp.tail() + kvp.producer_acquire(kvp_stage) + cute.copy(tma_v, tVgV[(None, Int32(kt))], tVsV[(None, kvp_stage.index)], tma_bar_ptr=kvp_stage.barrier) + kvp.producer_commit(kvp_stage); kvp_stage.advance() + qp.producer_tail(qp_stage) + kvp.producer_tail(kvp_stage) else: # Original pipeline path (hd≤256) qp.reset(); qh = qp.acquire_and_advance() @@ -246,6 +248,8 @@ class FmhaKernel: tmem.wait_for_alloc() if const_expr(self.n_k_sub_tiles > 1): # K sub-tiling path (hd=512): pipeline sync for k_sub iterations + qp_stage = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.q_stage) + kvp_stage = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.kv_stage) for kt in range(self.n_kv_tiles): for k_sub in range(self.n_k_sub_tiles): qh = qc.wait_and_advance(); qh.release()