diff --git a/dsv4/kernels/attention/fmha_smem_acc.py b/dsv4/kernels/attention/fmha_smem_acc.py index f02eaa6a..4b24d492 100644 --- a/dsv4/kernels/attention/fmha_smem_acc.py +++ b/dsv4/kernels/attention/fmha_smem_acc.py @@ -561,6 +561,8 @@ class FmhaKernel: cute.group_modes(sC, 0, 2), cute.group_modes(tCgC_epi, 0, 2), ) + # Slice gC with MMA tile coordinates (same as epilogue_tma_store) + bSG_gC = bSG_gC[(None, None, None, Int32(0), Int32(0), Int32(0))] # TMA store: only the first epilogue warp does the copy c_pipe = pipeline.PipelineTmaStore.create( num_stages=self.num_c_stage,