From e45b94c01beea8507c2eefaa3b6289059749ef1a Mon Sep 17 00:00:00 2001 From: biondizzle Date: Wed, 27 May 2026 06:44:37 +0000 Subject: [PATCH] Test: compare both normalized and un-normalized reference --- tests/unit/test_production.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/tests/unit/test_production.py b/tests/unit/test_production.py index 9a027da9..4f3690b9 100644 --- a/tests/unit/test_production.py +++ b/tests/unit/test_production.py @@ -16,24 +16,26 @@ def test_production_basic(): k = torch.randn(N, hd, dtype=torch.bfloat16, device='cuda') v = torch.randn(N, hd, dtype=torch.bfloat16, device='cuda') - # PyTorch reference + # PyTorch reference (un-normalized) qf = q[0].float() kf = k.float() vf = v.float() scale = 1.0 / math.sqrt(hd) - attn = qf @ kf.T * scale - attn_max = attn.max(dim=-1, keepdim=True)[0] - attn_exp = torch.exp(attn - attn_max) + attn_max = (qf @ kf.T * scale).max(dim=-1, keepdim=True)[0] + attn_exp = torch.exp(qf @ kf.T * scale - attn_max) attn_sum = attn_exp.sum(dim=-1, keepdim=True) - ref = ((attn_exp / attn_sum) @ vf).unsqueeze(0) + ref_unnorm = attn_exp @ vf + ref_norm = (attn_exp / attn_sum) @ vf out = dsv4_attention(q, k, v) - cos = torch.nn.functional.cosine_similarity( - out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0) + cos_unnorm = torch.nn.functional.cosine_similarity( + out.flatten().unsqueeze(0), ref_unnorm.unsqueeze(0).flatten().unsqueeze(0) ).item() - status = "PASS" if cos >= 0.99 else "FAIL" - print(f" hd={hd}, n_h={n_h}, N={N}: cos {cos:.6f} {status}") + cos_norm = torch.nn.functional.cosine_similarity( + out.flatten().unsqueeze(0), ref_norm.unsqueeze(0).flatten().unsqueeze(0) + ).item() + print(f" hd={hd}, n_h={n_h}, N={N}: cos_unnorm {cos_unnorm:.6f} cos_norm {cos_norm:.6f}") def test_production_multi_head():