CRITICAL FIX: add missing tb base in QK TMEM read address
prefill_read_qk_rows was reading from address 0 (sg_off + n * 8) instead of tb + sg_off + n * 8. This caused garbage QK values, explaining the 0.928 cosine for T=1 and NaN for T>1.
This commit is contained in:
@@ -106,7 +106,7 @@ __device__ void prefill_read_qk_rows(uint32_t tb, float* sLogits,
|
||||
asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];"
|
||||
: "=f"(tmp[0]),"=f"(tmp[1]),"=f"(tmp[2]),"=f"(tmp[3]),
|
||||
"=f"(tmp[4]),"=f"(tmp[5]),"=f"(tmp[6]),"=f"(tmp[7])
|
||||
: "r"(sg_off + n * 8));
|
||||
: "r"(tb + sg_off + n * 8));
|
||||
asm volatile("tcgen05.wait::ld.sync.aligned;" ::: "memory");
|
||||
|
||||
int row = warp_row + lane;
|
||||
|
||||
Reference in New Issue
Block a user