fix: epilogue_warp_id must be a tuple for epilogue_tma_store; compare warp_idx against epilogue_warp_id[0]

This commit is contained in:
2026-05-22 09:59:20 +00:00
parent 6ba12b7890
commit 59f4d8a469

View File

@@ -26,7 +26,7 @@ class FmhaV3StageC2:
self.softmax_warp_ids = (0, 1, 2, 3)
self.correction_warp_ids = (4, 5, 6, 7)
self.mma_warp_id = 8; self.tma_warp_id = 9
-self.epilogue_warp_id = 10; self.empty_warp_id = 11
+self.epilogue_warp_id = (10,); self.empty_warp_id = 11
self.threads_per_cta = 32 * 12
# Pipeline stages
self.mma_softmax_stage = 1; self.softmax_corr_stage = 1
@@ -403,7 +403,7 @@ class FmhaV3StageC2:
cute.arch.mbarrier_arrive(st.tmem_dealloc)
# ==================== EPILOGUE WARP (10) ====================
-if warp_idx == self.epilogue_warp_id:
+if warp_idx == self.epilogue_warp_id[0]:
tmem.wait_for_alloc()
tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype)
epi_handle = corr_epi_cons.wait_and_advance()