test: merge softmax+PV into a single warp-0 block so s_vals stays in scope for the PV step

This commit is contained in:
2026-05-28 13:10:02 +00:00
parent 5c9e3c41af
commit 654a2ae7f4

View File

@@ -70,7 +70,7 @@ test_fmha_hd64(const bf16_t* q, const bf16_t* k, const bf16_t* v,
asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
__syncthreads();
// Softmax
// Softmax + PV: warp 0 computes softmax and then PV in registers
if (wid == 0) {
float s_vals[SK], row_max = -INFINITY;
for (int n = 0; n < SK / 8; n++) {
@@ -94,12 +94,8 @@ test_fmha_hd64(const bf16_t* q, const bf16_t* k, const bf16_t* v,
asm volatile("tcgen05.st.sync.aligned.32x32b.x8.b32 [%0],{%1,%2,%3,%4,%5,%6,%7,%8};" :: "r"(tb+n*8),"f"(p0),"f"(p1),"f"(p2),"f"(p3),"f"(p4),"f"(p5),"f"(p6),"f"(p7));
}
tmem_fence_store();
}
__syncthreads();
// PV: O[d] = Σ P[0,j] × V[d,j] — computed in registers by warp 0
if (wid == 0) {
// P is in s_vals (lane 0). Compute O using s_vals directly (no TMEM re-read).
// PV: O[d] = Σ P[0,j] × V[d,j] — s_vals still in scope
if (lane == 0) {
for (int d = 0; d < HD; d++) {
float ov = 0.0f;