CRITICAL FIX: add missing tb base in QK TMEM read address

prefill_read_qk_rows was reading from address 0 (sg_off + n * 8)
instead of tb + sg_off + n * 8. This caused garbage QK values,
explaining the 0.928 cosine for T=1 and NaN for T>1.
This commit is contained in:
2026-06-03 03:00:57 +00:00
parent 99b6de316b
commit eb69c3bfb9

View File

@@ -106,7 +106,7 @@ __device__ void prefill_read_qk_rows(uint32_t tb, float* sLogits,
asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];"
: "=f"(tmp[0]),"=f"(tmp[1]),"=f"(tmp[2]),"=f"(tmp[3]),
"=f"(tmp[4]),"=f"(tmp[5]),"=f"(tmp[6]),"=f"(tmp[7])
: "r"(sg_off + n * 8));
: "r"(tb + sg_off + n * 8));
asm volatile("tcgen05.wait::ld.sync.aligned;" ::: "memory");
int row = warp_row + lane;