diff --git a/dsv4/kernels/attention/fmha_smem_acc.py b/dsv4/kernels/attention/fmha_smem_acc.py index c999c0bd..39b190e0 100644 --- a/dsv4/kernels/attention/fmha_smem_acc.py +++ b/dsv4/kernels/attention/fmha_smem_acc.py @@ -626,7 +626,7 @@ class FmhaKernel: gO_qdl = cute.flat_divide( mC, cute.select(self.pv_mma_tiler, mode=[0, 1]) ) - gO = gO_qdl[None, None, None, Int32(0), (Int32(0), Int32(0))] + gO = gO_qdl[None, None, None, Int32(0), Int32(0)] tOsO, tOgO = cpasync.tma_partition( tma_c, 0, cute.make_layout(1), cute.group_modes(sC, 0, 2),