Debug: skip PV step entirely
This commit is contained in:
@@ -128,9 +128,8 @@ test_fmha_ts_full(const bf16_t* __restrict__ q, const bf16_t* __restrict__ k,
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// ===== STEP 3: PV GEMM (TS) =====
|
||||
// P(128,128) × V(128,16) → O(128,16)
|
||||
// 8 K-tiles: A = P cols [16*kt..16*kt+16), B = V K-tile kt
|
||||
// ===== STEP 3: PV GEMM (TS) — SKIPPED FOR DEBUG =====
|
||||
/*
|
||||
{
|
||||
uint32_t idesc_pv = make_idesc(BLOCK_MN, HD);
|
||||
|
||||
@@ -147,6 +146,7 @@ test_fmha_ts_full(const bf16_t* __restrict__ q, const bf16_t* __restrict__ k,
|
||||
}
|
||||
asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
|
||||
__syncthreads();
|
||||
*/
|
||||
|
||||
// ===== STEP 4: Epilogue — read O from TMEM =====
|
||||
// MMA output is scaled by 0.5, so multiply by 2.0
|
||||
|
||||
Reference in New Issue
Block a user