fix regression test: use normalize=False + external LSE normalization

This commit is contained in:
2026-05-25 17:06:21 +00:00
parent 13b5afc471
commit 06cb800242

View File

@@ -18,14 +18,20 @@ def test():
v = torch.randn(s_k, hd, 1, dtype=torch.bfloat16, device='cuda')
o = torch.zeros(M, hd, 1, dtype=torch.bfloat16, device='cuda')
fmha = FmhaKernel(head_dim=hd, s_k=s_k, normalize=True)
fmha = FmhaKernel(head_dim=hd, s_k=s_k, normalize=False)
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
q_c = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
k_c = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k))
v_c = ct.from_dlpack(v).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v))
o_c = ct.from_dlpack(o).mark_layout_dynamic(leading_dim=ct.get_leading_dim(o))
fmha(q_c, k_c, v_c, o_c, stream)
lse = torch.zeros(M, dtype=torch.float32, device='cuda')
lse_c = ct.from_dlpack(lse).mark_layout_dynamic(leading_dim=ct.get_leading_dim(lse))
fmha(q_c, k_c, v_c, o_c, stream, lse_c)
# External normalization using LSE
row_sum = lse.exp()
o_norm = o[:,:,0] / row_sum.unsqueeze(-1)
# Reference
scores = torch.matmul(q[:,:,0].float(), k[:,:,0].float().T) * scale
@@ -36,9 +42,9 @@ def test():
ref = torch.matmul(p, v[:,:,0].float()).to(torch.bfloat16)
cos = torch.nn.functional.cosine_similarity(
o[:,:,0].flatten().float().unsqueeze(0), ref.flatten().float().unsqueeze(0)
o_norm.flatten().float().unsqueeze(0), ref.flatten().float().unsqueeze(0)
).item()
print(f" cos = {cos:.6f}")
print(f" cos (ext norm) = {cos:.6f}")
if cos >= 0.99:
print(" ✅ PASS")
else: