Fix kv_ref transpose in KV cache test

This commit is contained in:
2026-05-19 08:58:46 +00:00
parent c1099d76d2
commit d60673864a

View File

@@ -288,7 +288,7 @@ def main():
# Full BF16 reference
qa_ref = normed @ qa_bf16_ref.T
kv_ref = normed @ kv_bf16_ref
kv_ref = normed @ kv_bf16_ref.T
qa_n_ref = rms(qa_ref, qn, EPS)
kv_n_ref = rms(kv_ref, kvn, EPS)
q_ref = (qa_n_ref @ qb_bf16_ref.T).view(NT, NH, HD)