diff --git a/dsv4/kernels/attention/fmha_smem_acc.py b/dsv4/kernels/attention/fmha_smem_acc.py index 15ec7b42..9e52d7b7 100644 --- a/dsv4/kernels/attention/fmha_smem_acc.py +++ b/dsv4/kernels/attention/fmha_smem_acc.py @@ -634,7 +634,7 @@ class FmhaKernel: cute.group_modes(gO, 0, 2), ) if warp_idx == self.epilogue_warp_id[0]: - cute.copy(tma_c, tOsC[None, Int32(0)], tOgO[None, Int32(0)]) + cute.copy(tma_c, tOsC, tOgO) cute.arch.cp_async_bulk_commit_group() cute.arch.cp_async_bulk_wait_group(0, read=True)