diff --git a/dsv4/kernels/attention/fmha_smem_acc.py b/dsv4/kernels/attention/fmha_smem_acc.py index e44a3857..9b9486f0 100644 --- a/dsv4/kernels/attention/fmha_smem_acc.py +++ b/dsv4/kernels/attention/fmha_smem_acc.py @@ -636,7 +636,7 @@ class FmhaKernel: ) c_pipe.producer_acquire() if warp_idx == self.epilogue_warp_id[0]: - cute.copy(tma_c, bSG_sC[(None, Int32(0))], bSG_gC[(None, Int32(0))]) + cute.copy(tma_c, bSG_sC[(None, None, Int32(0))], bSG_gC[(None, None, Int32(0))]) c_pipe.producer_commit() c_pipe.producer_tail()