diff --git a/tests/unit/test_d5c_fused.py b/tests/unit/test_d5c_fused.py index 5a12760b..42ecf58e 100644 --- a/tests/unit/test_d5c_fused.py +++ b/tests/unit/test_d5c_fused.py @@ -190,19 +190,27 @@ def test_d5c_combined(): torch.cuda.synchronize() # Check results - o_kernel = c_out[:, :, 0].float() + # Kernel outputs UN-NORMALIZED O (normalize=False). Normalize using per-row LSE. + # O_norm[i] = O_unnorm[i] * exp(-lse[i]) + o_kernel_unnorm = c_out[:, :, 0].float() # (m, hd) + lse_kernel = lse_out[:, 0, 0].float() # (m,) + + # Normalize each row + o_kernel = o_kernel_unnorm * (-lse_kernel.unsqueeze(1)).exp() # (m, hd) + cos = torch.nn.functional.cosine_similarity( o_kernel.flatten().unsqueeze(0), ref_combined.flatten().unsqueeze(0).float() ).item() max_abs = (o_kernel - ref_combined.float()).abs().max().item() - status = "PASS" if cos >= 0.95 else "FAIL" + status = "PASS" if cos >= 0.99 else "FAIL" print(f'\nD5c result: cos {cos:.6f} max_abs {max_abs:.4f} {status}') - if cos < 0.95: + if cos < 0.99: print(f' kernel[0,:4]={o_kernel[0,:4].tolist()}') print(f' ref[0,:4]={ref_combined[0,:4].tolist()}') + print(f' LSE range: {lse_kernel.min().item():.4f} to {lse_kernel.max().item():.4f}') def test_d5c_with_causal(): @@ -272,14 +280,17 @@ def test_d5c_with_causal(): ) torch.cuda.synchronize() - o_kernel = c_out[:, :, 0].float() + o_kernel_unnorm = c_out[:, :, 0].float() + lse_kernel = lse_out[:, 0, 0].float() + o_kernel = o_kernel_unnorm * (-lse_kernel.unsqueeze(1)).exp() + cos = torch.nn.functional.cosine_similarity( o_kernel.flatten().unsqueeze(0), ref.flatten().unsqueeze(0).float() ).item() max_abs = (o_kernel - ref.float()).abs().max().item() - status = "PASS" if cos >= 0.95 else "FAIL" + status = "PASS" if cos >= 0.99 else "FAIL" print(f'D5c causal result: cos {cos:.6f} max_abs {max_abs:.4f} {status}')