diff --git a/tests/unit/test_production.py b/tests/unit/test_production.py index 9a027da9..4f3690b9 100644 --- a/tests/unit/test_production.py +++ b/tests/unit/test_production.py @@ -16,24 +16,26 @@ def test_production_basic(): k = torch.randn(N, hd, dtype=torch.bfloat16, device='cuda') v = torch.randn(N, hd, dtype=torch.bfloat16, device='cuda') - # PyTorch reference + # PyTorch reference (un-normalized) qf = q[0].float() kf = k.float() vf = v.float() scale = 1.0 / math.sqrt(hd) - attn = qf @ kf.T * scale - attn_max = attn.max(dim=-1, keepdim=True)[0] - attn_exp = torch.exp(attn - attn_max) + attn_max = (qf @ kf.T * scale).max(dim=-1, keepdim=True)[0] + attn_exp = torch.exp(qf @ kf.T * scale - attn_max) attn_sum = attn_exp.sum(dim=-1, keepdim=True) - ref = ((attn_exp / attn_sum) @ vf).unsqueeze(0) + ref_unnorm = attn_exp @ vf + ref_norm = (attn_exp / attn_sum) @ vf out = dsv4_attention(q, k, v) - cos = torch.nn.functional.cosine_similarity( - out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0) + cos_unnorm = torch.nn.functional.cosine_similarity( + out.flatten().unsqueeze(0), ref_unnorm.unsqueeze(0).flatten().unsqueeze(0) ).item() - status = "PASS" if cos >= 0.99 else "FAIL" - print(f" hd={hd}, n_h={n_h}, N={N}: cos {cos:.6f} {status}") + cos_norm = torch.nn.functional.cosine_similarity( + out.flatten().unsqueeze(0), ref_norm.unsqueeze(0).flatten().unsqueeze(0) + ).item() + print(f" hd={hd}, n_h={n_h}, N={N}: cos_unnorm {cos_unnorm:.6f} cos_norm {cos_norm:.6f}") def test_production_multi_head():