Test: compare both normalized and un-normalized reference

This commit is contained in:
2026-05-27 06:44:37 +00:00
parent b70ab2a6ee
commit e45b94c01b

View File

@@ -16,24 +16,26 @@ def test_production_basic():
k = torch.randn(N, hd, dtype=torch.bfloat16, device='cuda')
v = torch.randn(N, hd, dtype=torch.bfloat16, device='cuda')
# PyTorch reference
# PyTorch reference (un-normalized)
qf = q[0].float()
kf = k.float()
vf = v.float()
scale = 1.0 / math.sqrt(hd)
attn = qf @ kf.T * scale
attn_max = attn.max(dim=-1, keepdim=True)[0]
attn_exp = torch.exp(attn - attn_max)
attn_max = (qf @ kf.T * scale).max(dim=-1, keepdim=True)[0]
attn_exp = torch.exp(qf @ kf.T * scale - attn_max)
attn_sum = attn_exp.sum(dim=-1, keepdim=True)
ref = ((attn_exp / attn_sum) @ vf).unsqueeze(0)
ref_unnorm = attn_exp @ vf
ref_norm = (attn_exp / attn_sum) @ vf
out = dsv4_attention(q, k, v)
cos = torch.nn.functional.cosine_similarity(
out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)
cos_unnorm = torch.nn.functional.cosine_similarity(
out.flatten().unsqueeze(0), ref_unnorm.unsqueeze(0).flatten().unsqueeze(0)
).item()
status = "PASS" if cos >= 0.99 else "FAIL"
print(f" hd={hd}, n_h={n_h}, N={N}: cos {cos:.6f} {status}")
cos_norm = torch.nn.functional.cosine_similarity(
out.flatten().unsqueeze(0), ref_norm.unsqueeze(0).flatten().unsqueeze(0)
).item()
print(f" hd={hd}, n_h={n_h}, N={N}: cos_unnorm {cos_unnorm:.6f} cos_norm {cos_norm:.6f}")
def test_production_multi_head():