From 59f4d8a4697f1918b5cc874da535e529467059f9 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Fri, 22 May 2026 09:59:20 +0000 Subject: [PATCH] fix: epilogue_warp_id must be tuple for epilogue_tma_store, check with [0] --- tests/unit/test_fmha_v3_stage_c2.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/unit/test_fmha_v3_stage_c2.py b/tests/unit/test_fmha_v3_stage_c2.py index d79ef7e3..2f54a7c5 100644 --- a/tests/unit/test_fmha_v3_stage_c2.py +++ b/tests/unit/test_fmha_v3_stage_c2.py @@ -26,7 +26,7 @@ class FmhaV3StageC2: self.softmax_warp_ids = (0, 1, 2, 3) self.correction_warp_ids = (4, 5, 6, 7) self.mma_warp_id = 8; self.tma_warp_id = 9 - self.epilogue_warp_id = 10; self.empty_warp_id = 11 + self.epilogue_warp_id = (10,); self.empty_warp_id = 11 self.threads_per_cta = 32 * 12 # Pipeline stages self.mma_softmax_stage = 1; self.softmax_corr_stage = 1 @@ -403,7 +403,7 @@ class FmhaV3StageC2: cute.arch.mbarrier_arrive(st.tmem_dealloc) # ==================== EPILOGUE WARP (10) ==================== - if warp_idx == self.epilogue_warp_id: + if warp_idx == self.epilogue_warp_id[0]: tmem.wait_for_alloc() tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype) epi_handle = corr_epi_cons.wait_and_advance()