fix: normalize kernel output using per-row LSE in the D5c tests and tighten the cosine-similarity pass threshold from 0.95 to 0.99

This commit is contained in:
2026-05-26 15:04:47 +00:00
parent 014d647ba3
commit 31e6426049

View File

@@ -190,19 +190,27 @@ def test_d5c_combined():
torch.cuda.synchronize()
# Check results
o_kernel = c_out[:, :, 0].float()
# The kernel is run with normalize=False, so c_out holds the UN-NORMALIZED output O.
# Normalize each row i by its log-sum-exp (LSE): O_norm[i] = O_unnorm[i] * exp(-lse[i])
o_kernel_unnorm = c_out[:, :, 0].float() # (m, hd)
lse_kernel = lse_out[:, 0, 0].float() # (m,)
# Normalize each row
o_kernel = o_kernel_unnorm * (-lse_kernel.unsqueeze(1)).exp() # (m, hd)
cos = torch.nn.functional.cosine_similarity(
o_kernel.flatten().unsqueeze(0),
ref_combined.flatten().unsqueeze(0).float()
).item()
max_abs = (o_kernel - ref_combined.float()).abs().max().item()
status = "PASS" if cos >= 0.95 else "FAIL"
status = "PASS" if cos >= 0.99 else "FAIL"
print(f'\nD5c result: cos {cos:.6f} max_abs {max_abs:.4f} {status}')
if cos < 0.95:
if cos < 0.99:
print(f' kernel[0,:4]={o_kernel[0,:4].tolist()}')
print(f' ref[0,:4]={ref_combined[0,:4].tolist()}')
print(f' LSE range: {lse_kernel.min().item():.4f} to {lse_kernel.max().item():.4f}')
def test_d5c_with_causal():
@@ -272,14 +280,17 @@ def test_d5c_with_causal():
)
torch.cuda.synchronize()
o_kernel = c_out[:, :, 0].float()
o_kernel_unnorm = c_out[:, :, 0].float()
lse_kernel = lse_out[:, 0, 0].float()
o_kernel = o_kernel_unnorm * (-lse_kernel.unsqueeze(1)).exp()
cos = torch.nn.functional.cosine_similarity(
o_kernel.flatten().unsqueeze(0),
ref.flatten().unsqueeze(0).float()
).item()
max_abs = (o_kernel - ref.float()).abs().max().item()
status = "PASS" if cos >= 0.95 else "FAIL"
status = "PASS" if cos >= 0.99 else "FAIL"
print(f'D5c causal result: cos {cos:.6f} max_abs {max_abs:.4f} {status}')