test: merge softmax+PV into single warp0 block (s_vals scope fix)
This commit is contained in:
@@ -70,7 +70,7 @@ test_fmha_hd64(const bf16_t* q, const bf16_t* k, const bf16_t* v,
|
||||
asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
|
||||
__syncthreads();
|
||||
|
||||
// Softmax
|
||||
// Softmax + PV: warp 0 computes softmax and then PV in registers
|
||||
if (wid == 0) {
|
||||
float s_vals[SK], row_max = -INFINITY;
|
||||
for (int n = 0; n < SK / 8; n++) {
|
||||
@@ -94,12 +94,8 @@ test_fmha_hd64(const bf16_t* q, const bf16_t* k, const bf16_t* v,
|
||||
asm volatile("tcgen05.st.sync.aligned.32x32b.x8.b32 [%0],{%1,%2,%3,%4,%5,%6,%7,%8};" :: "r"(tb+n*8),"f"(p0),"f"(p1),"f"(p2),"f"(p3),"f"(p4),"f"(p5),"f"(p6),"f"(p7));
|
||||
}
|
||||
tmem_fence_store();
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// PV: O[d] = Σ P[0,j] × V[d,j] — computed in registers by warp 0
|
||||
if (wid == 0) {
|
||||
// P is in s_vals (lane 0). Compute O using s_vals directly (no TMEM re-read).
|
||||
// PV: O[d] = Σ P[0,j] × V[d,j] — s_vals still in scope
|
||||
if (lane == 0) {
|
||||
for (int d = 0; d < HD; d++) {
|
||||
float ov = 0.0f;
|
||||
|
||||
Reference in New Issue
Block a user