prefill_read_qk_rows computed its read address as sg_off + n * 8, i.e. an offset from address 0, instead of tb + sg_off + n * 8. The reads therefore returned garbage QK values, which explains the 0.928 cosine for T=1 and the NaNs for T>1.