diff --git a/tests/unit/test_fmha_v3_stage_c2.py b/tests/unit/test_fmha_v3_stage_c2.py index ea2d85c0..36d66b44 100644 --- a/tests/unit/test_fmha_v3_stage_c2.py +++ b/tests/unit/test_fmha_v3_stage_c2.py @@ -415,12 +415,9 @@ class FmhaV3StageC2: acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, 1) c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32) c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp) - acc_pipe_cons = acc_pipe # full pipeline for epilogue_tma_store - acc_cons_st = utils.gemm.sm100.epilogue_tma_store(self, tidx, warp_idx, tma_c, tCtO_base, sC, tCgC, epi_tile, 0, const_expr(lambda x: x), (0,0,0), acc_cons_st, acc_pipe_cons, c_pipe) + acc_cons_st = utils.gemm.sm100.epilogue_tma_store(self, tidx, warp_idx, tma_c, tCtO_base, sC, tCgC, epi_tile, 0, const_expr(lambda x: x), (0,0,0), acc_cons_st, acc_pipe, c_pipe) c_pipe.producer_tail() epi_handle.release() - tmem.relinquish_alloc_permit() - tmem.free(tmem_ptr) def test(): torch.manual_seed(42) for n in [128]: