Fix kv_ref transpose in KV cache test
This commit is contained in:
@@ -288,7 +288,7 @@ def main():
|
||||
|
||||
# Full BF16 reference
|
||||
qa_ref = normed @ qa_bf16_ref.T
|
||||
kv_ref = normed @ kv_bf16_ref
|
||||
kv_ref = normed @ kv_bf16_ref.T
|
||||
qa_n_ref = rms(qa_ref, qn, EPS)
|
||||
kv_n_ref = rms(kv_ref, kvn, EPS)
|
||||
q_ref = (qa_n_ref @ qb_bf16_ref.T).view(NT, NH, HD)
|
||||
|
||||
Reference in New Issue
Block a user