fix: remove duplicate tmem free from epilogue (MMA warp handles dealloc)
This commit is contained in:
@@ -415,12 +415,9 @@ class FmhaV3StageC2:
|
||||
acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, 1)
|
||||
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32)
|
||||
c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
|
||||
acc_pipe_cons = acc_pipe # full pipeline for epilogue_tma_store
|
||||
acc_cons_st = utils.gemm.sm100.epilogue_tma_store(self, tidx, warp_idx, tma_c, tCtO_base, sC, tCgC, epi_tile, 0, const_expr(lambda x: x), (0,0,0), acc_cons_st, acc_pipe_cons, c_pipe)
|
||||
acc_cons_st = utils.gemm.sm100.epilogue_tma_store(self, tidx, warp_idx, tma_c, tCtO_base, sC, tCgC, epi_tile, 0, const_expr(lambda x: x), (0,0,0), acc_cons_st, acc_pipe, c_pipe)
|
||||
c_pipe.producer_tail()
|
||||
epi_handle.release()
|
||||
tmem.relinquish_alloc_permit()
|
||||
tmem.free(tmem_ptr)
|
||||
def test():
|
||||
torch.manual_seed(42)
|
||||
for n in [128]:
|
||||
|
||||
Reference in New Issue
Block a user