From df10378bb58fdf043a5adede43b7d024372d3edf Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sun, 24 May 2026 15:13:16 +0000 Subject: [PATCH] D1.4: Fix regression test for un-normalized O output (D5a) --- tests/unit/test_d1_regression.py | 35 +++++++++++++++++++++----------- 1 file changed, 23 insertions(+), 12 deletions(-) diff --git a/tests/unit/test_d1_regression.py b/tests/unit/test_d1_regression.py index 5a6258b3..bfa5ebf5 100644 --- a/tests/unit/test_d1_regression.py +++ b/tests/unit/test_d1_regression.py @@ -1,4 +1,5 @@ -"""Quick D1 regression test: HEAD_DIM=64 only, must match Stage C.""" +"""Quick D1 regression test: HEAD_DIM=64 only, must match Stage C. +Kernel outputs un-normalized O + LSE (D5a path).""" import torch, math import cutlass.cute as cute import cutlass.torch as ct @@ -15,31 +16,41 @@ def test(): v_kernel = v.unsqueeze(-1) c = torch.zeros(m, hd, 1, dtype=torch.bfloat16, device='cuda') + # FP32 reference (un-normalized + normalized) qf = q[:, :, 0].float() kf = k[:, :, 0].float() scale = 1.0 / math.sqrt(hd) - attn = qf @ kf.T * scale - attn = torch.softmax(attn, dim=-1) - ref = attn @ v.float() + attn_max = (qf @ kf.T * scale).max(dim=-1, keepdim=True)[0] + attn_exp = torch.exp(qf @ kf.T * scale - attn_max) + attn_sum = attn_exp.sum(dim=-1, keepdim=True) + ref_unnorm = attn_exp @ v.float() + ref_norm = (attn_exp / attn_sum) @ v.float() + + lse_tensor = torch.zeros(m, 1, 1, dtype=torch.float32, device='cuda') mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q)) mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k)) mV = ct.from_dlpack(v_kernel).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_kernel)) mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c)) + mLSE = ct.from_dlpack(lse_tensor).mark_layout_dynamic(leading_dim=ct.get_leading_dim(lse_tensor)) stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) - kernel = FmhaKernel(head_dim=hd, s_k=n) + # normalize=False: kernel outputs un-normalized O + LSE + kernel = FmhaKernel(head_dim=hd, s_k=n, normalize=False) print(f'hd={hd}, n={n}: Compiling...', flush=True) - compiled = cute.compile(kernel, mQ, mK, mV, mC, stream) - compiled(mQ, mK, mV, mC, stream) + compiled = cute.compile(kernel, mQ, mK, mV, mC, stream, mLSE) + compiled(mQ, mK, mV, mC, stream, mLSE) torch.cuda.synchronize() - out = c[:, :, 0].float() - cos = torch.nn.functional.cosine_similarity( - out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0) + out_unnorm = c[:, :, 0].float() + out_norm = out_unnorm / attn_sum # external normalization using row_sum + cos_unnorm = torch.nn.functional.cosine_similarity( + out_unnorm.flatten().unsqueeze(0), ref_unnorm.flatten().unsqueeze(0) ).item() - max_abs = (out - ref).abs().max().item() - print(f'hd={hd}, n={n}: cos {cos:.6f} max_abs {max_abs:.4f} {"PASS" if cos >= 0.97 else "FAIL"}') + cos_norm = torch.nn.functional.cosine_similarity( + out_norm.flatten().unsqueeze(0), ref_norm.flatten().unsqueeze(0) + ).item() + print(f'hd={hd}, n={n}: cos_unnorm {cos_unnorm:.6f} cos_norm {cos_norm:.6f} {"PASS" if cos_norm >= 0.99 else "FAIL"}') if __name__ == '__main__': test()