diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index 33f1b051..7214f10c 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -221,15 +221,15 @@ class FmhaKernel: # Load Q[k_sub] → sQ qh = qp.acquire_and_advance() cute.copy(tma_q, tAgQ[(None, Int32(k_sub))], tAsQ[(None, qh.index)], tma_bar_ptr=qh.barrier) + qh.commit() # Load K[k_sub] → sK kvh = kvp.acquire_and_advance() cute.copy(tma_k, tBgK[(None, Int32(k_sub))], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier) - qh = qp.wait_and_advance(); qh.release() - kvh = kvp.wait_and_advance(); pk = cutlass.Boolean(1) + kvh.commit() # Load V[kt] → sV - kvh = kvp.acquire_and_advance(pk) + kvh = kvp.acquire_and_advance() cute.copy(tma_v, tVgV[(None, Int32(kt))], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier) - kvh = kvp.wait_and_advance(); pk = cutlass.Boolean(1) + kvh.commit() qp.tail() kvp.tail() else: