diff --git a/tests/unit/test_fmha_hd64.cu b/tests/unit/test_fmha_hd64.cu index 528e385f..0513102b 100644 --- a/tests/unit/test_fmha_hd64.cu +++ b/tests/unit/test_fmha_hd64.cu @@ -70,7 +70,7 @@ test_fmha_hd64(const bf16_t* q, const bf16_t* k, const bf16_t* v, asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory"); __syncthreads(); - // Softmax + // Softmax + PV: warp 0 computes softmax and then PV in registers if (wid == 0) { float s_vals[SK], row_max = -INFINITY; for (int n = 0; n < SK / 8; n++) { @@ -94,12 +94,8 @@ test_fmha_hd64(const bf16_t* q, const bf16_t* k, const bf16_t* v, asm volatile("tcgen05.st.sync.aligned.32x32b.x8.b32 [%0],{%1,%2,%3,%4,%5,%6,%7,%8};" :: "r"(tb+n*8),"f"(p0),"f"(p1),"f"(p2),"f"(p3),"f"(p4),"f"(p5),"f"(p6),"f"(p7)); } tmem_fence_store(); - } - __syncthreads(); - // PV: O[d] = Σ P[0,j] × V[d,j] — computed in registers by warp 0 - if (wid == 0) { - // P is in s_vals (lane 0). Compute O using s_vals directly (no TMEM re-read). + // PV: O[d] = Σ P[0,j] × V[d,j] — s_vals still in scope if (lane == 0) { for (int d = 0; d < HD; d++) { float ov = 0.0f;