diff --git a/tests/test_kv_cache_b200.py b/tests/test_kv_cache_b200.py index e5d13358..29625338 100644 --- a/tests/test_kv_cache_b200.py +++ b/tests/test_kv_cache_b200.py @@ -288,7 +288,7 @@ def main(): # Full BF16 reference qa_ref = normed @ qa_bf16_ref.T - kv_ref = normed @ kv_bf16_ref + kv_ref = normed @ kv_bf16_ref.T qa_n_ref = rms(qa_ref, qn, EPS) kv_n_ref = rms(kv_ref, kvn, EPS) q_ref = (qa_n_ref @ qb_bf16_ref.T).view(NT, NH, HD)