Test: compare against both normalized and un-normalized references
This commit is contained in:
@@ -16,24 +16,26 @@ def test_production_basic():
|
||||
k = torch.randn(N, hd, dtype=torch.bfloat16, device='cuda')
|
||||
v = torch.randn(N, hd, dtype=torch.bfloat16, device='cuda')
|
||||
|
||||
# PyTorch reference
|
||||
# PyTorch reference (un-normalized)
|
||||
qf = q[0].float()
|
||||
kf = k.float()
|
||||
vf = v.float()
|
||||
scale = 1.0 / math.sqrt(hd)
|
||||
attn = qf @ kf.T * scale
|
||||
attn_max = attn.max(dim=-1, keepdim=True)[0]
|
||||
attn_exp = torch.exp(attn - attn_max)
|
||||
attn_max = (qf @ kf.T * scale).max(dim=-1, keepdim=True)[0]
|
||||
attn_exp = torch.exp(qf @ kf.T * scale - attn_max)
|
||||
attn_sum = attn_exp.sum(dim=-1, keepdim=True)
|
||||
ref = ((attn_exp / attn_sum) @ vf).unsqueeze(0)
|
||||
ref_unnorm = attn_exp @ vf
|
||||
ref_norm = (attn_exp / attn_sum) @ vf
|
||||
|
||||
out = dsv4_attention(q, k, v)
|
||||
|
||||
cos = torch.nn.functional.cosine_similarity(
|
||||
out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)
|
||||
cos_unnorm = torch.nn.functional.cosine_similarity(
|
||||
out.flatten().unsqueeze(0), ref_unnorm.unsqueeze(0).flatten().unsqueeze(0)
|
||||
).item()
|
||||
status = "PASS" if cos >= 0.99 else "FAIL"
|
||||
print(f" hd={hd}, n_h={n_h}, N={N}: cos {cos:.6f} {status}")
|
||||
cos_norm = torch.nn.functional.cosine_similarity(
|
||||
out.flatten().unsqueeze(0), ref_norm.unsqueeze(0).flatten().unsqueeze(0)
|
||||
).item()
|
||||
print(f" hd={hd}, n_h={n_h}, N={N}: cos_unnorm {cos_unnorm:.6f} cos_norm {cos_norm:.6f}")
|
||||
|
||||
|
||||
def test_production_multi_head():
|
||||
|
||||
Reference in New Issue
Block a user