fix regression test: use normalize=False + external LSE normalization
This commit is contained in:
@@ -18,14 +18,20 @@ def test():
|
||||
v = torch.randn(s_k, hd, 1, dtype=torch.bfloat16, device='cuda')
|
||||
o = torch.zeros(M, hd, 1, dtype=torch.bfloat16, device='cuda')
|
||||
|
||||
fmha = FmhaKernel(head_dim=hd, s_k=s_k, normalize=True)
|
||||
fmha = FmhaKernel(head_dim=hd, s_k=s_k, normalize=False)
|
||||
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
|
||||
|
||||
q_c = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
|
||||
k_c = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k))
|
||||
v_c = ct.from_dlpack(v).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v))
|
||||
o_c = ct.from_dlpack(o).mark_layout_dynamic(leading_dim=ct.get_leading_dim(o))
|
||||
fmha(q_c, k_c, v_c, o_c, stream)
|
||||
lse = torch.zeros(M, dtype=torch.float32, device='cuda')
|
||||
lse_c = ct.from_dlpack(lse).mark_layout_dynamic(leading_dim=ct.get_leading_dim(lse))
|
||||
fmha(q_c, k_c, v_c, o_c, stream, lse_c)
|
||||
|
||||
# External normalization: divide the normalize=False kernel output by exp(LSE)
|
||||
row_sum = lse.exp()
|
||||
o_norm = o[:,:,0] / row_sum.unsqueeze(-1)
|
||||
|
||||
# Reference
|
||||
scores = torch.matmul(q[:,:,0].float(), k[:,:,0].float().T) * scale
|
||||
@@ -36,9 +42,9 @@ def test():
|
||||
ref = torch.matmul(p, v[:,:,0].float()).to(torch.bfloat16)
|
||||
|
||||
cos = torch.nn.functional.cosine_similarity(
|
||||
o[:,:,0].flatten().float().unsqueeze(0), ref.flatten().float().unsqueeze(0)
|
||||
o_norm.flatten().float().unsqueeze(0), ref.flatten().float().unsqueeze(0)
|
||||
).item()
|
||||
print(f" cos = {cos:.6f}")
|
||||
print(f" cos (ext norm) = {cos:.6f}")
|
||||
if cos >= 0.99:
|
||||
print(" ✅ PASS")
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user