diff --git a/tests/unit/test_fmha_ts_full.cu b/tests/unit/test_fmha_ts_full.cu index 77c9888a..474491bf 100644 --- a/tests/unit/test_fmha_ts_full.cu +++ b/tests/unit/test_fmha_ts_full.cu @@ -128,9 +128,8 @@ test_fmha_ts_full(const bf16_t* __restrict__ q, const bf16_t* __restrict__ k, } __syncthreads(); - // ===== STEP 3: PV GEMM (TS) ===== - // P(128,128) × V(128,16) → O(128,16) - // 8 K-tiles: A = P cols [16*kt..16*kt+15), B = V K-tile kt + // ===== STEP 3: PV GEMM (TS) — SKIPPED FOR DEBUG ===== + /* { uint32_t idesc_pv = make_idesc(BLOCK_MN, HD); @@ -147,6 +146,7 @@ test_fmha_ts_full(const bf16_t* __restrict__ q, const bf16_t* __restrict__ k, } asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory"); __syncthreads(); + */ // ===== STEP 4: Epilogue — read O from TMEM ===== // MMA output is scaled by 0.5, so multiply by 2.0