From 06cb8002422341e4d36d7e2d31c92a20c1dbd2f7 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Mon, 25 May 2026 17:06:21 +0000 Subject: [PATCH] fix regression test: use normalize=False + external LSE normalization --- tests/unit/test_d2_regression.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/tests/unit/test_d2_regression.py b/tests/unit/test_d2_regression.py index af9ad33e..938ee5dc 100644 --- a/tests/unit/test_d2_regression.py +++ b/tests/unit/test_d2_regression.py @@ -18,14 +18,20 @@ def test(): v = torch.randn(s_k, hd, 1, dtype=torch.bfloat16, device='cuda') o = torch.zeros(M, hd, 1, dtype=torch.bfloat16, device='cuda') - fmha = FmhaKernel(head_dim=hd, s_k=s_k, normalize=True) + fmha = FmhaKernel(head_dim=hd, s_k=s_k, normalize=False) stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) q_c = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q)) k_c = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k)) v_c = ct.from_dlpack(v).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v)) o_c = ct.from_dlpack(o).mark_layout_dynamic(leading_dim=ct.get_leading_dim(o)) - fmha(q_c, k_c, v_c, o_c, stream) + lse = torch.zeros(M, dtype=torch.float32, device='cuda') + lse_c = ct.from_dlpack(lse).mark_layout_dynamic(leading_dim=ct.get_leading_dim(lse)) + fmha(q_c, k_c, v_c, o_c, stream, lse_c) + + # External normalization using LSE + row_sum = lse.exp() + o_norm = o[:,:,0] / row_sum.unsqueeze(-1) # Reference scores = torch.matmul(q[:,:,0].float(), k[:,:,0].float().T) * scale @@ -36,9 +42,9 @@ def test(): ref = torch.matmul(p, v[:,:,0].float()).to(torch.bfloat16) cos = torch.nn.functional.cosine_similarity( - o[:,:,0].flatten().float().unsqueeze(0), ref.flatten().float().unsqueeze(0) + o_norm.flatten().float().unsqueeze(0), ref.flatten().float().unsqueeze(0) ).item() - print(f" cos = {cos:.6f}") + print(f" cos (ext norm) = {cos:.6f}") if cos >= 0.99: print(" ✅ PASS") else: