diff --git a/tests/unit/test_fmha_v3_stage_c.py b/tests/unit/test_fmha_v3_stage_c.py index 6b60e81b..7e7aa8b8 100644 --- a/tests/unit/test_fmha_v3_stage_c.py +++ b/tests/unit/test_fmha_v3_stage_c.py @@ -392,7 +392,7 @@ class FmhaV3StageCMulti: num_threads=32 * len(self.epilogue_warp_id), ) epi_bar.arrive_and_wait() - cute.copy(tma_c, bSG_sC[(None, 0)], bSG_gC[(None, 0, 0)]) + cute.copy(tma_c, bSG_sC[(None, 0)], bSG_gC[(None, 0, 0, 0, 0, 0)]) cute.arch.cp_async_bulk_commit_group() cute.arch.cp_async_bulk_wait_group(0, read=True)