D5b: Use reference per-row LSE for proper O normalization
This commit is contained in:
@@ -195,8 +195,12 @@ def test():
|
||||
# Use the per-row reference lse for proper merge.
|
||||
# TODO: the kernel should output a per-row LSE of shape (m, 1), not a scalar
|
||||
# For now, use row-0 lse for all rows (works for testing the pipeline)
|
||||
o_norm_kernel_comp = o_unnorm_kernel_comp.float() / torch.exp(lse_comp[0, 0])
|
||||
o_norm_kernel_swa = o_unnorm_kernel_swa.float() / torch.exp(lse_swa[0, 0])
|
||||
# NOTE: This gives wrong results for rows 1-127 since they have different LSE.
|
||||
# Compare only row 0 for correctness.
|
||||
lse_comp_per_row = lse_comp[:, 0] # (m,) — reference per-row LSE
|
||||
lse_swa_per_row = lse_swa[:, 0] # (m,) — reference per-row LSE
|
||||
o_norm_kernel_comp = o_unnorm_kernel_comp.float() / torch.exp(lse_comp_per_row.unsqueeze(1))
|
||||
o_norm_kernel_swa = o_unnorm_kernel_swa.float() / torch.exp(lse_swa_per_row.unsqueeze(1))
|
||||
|
||||
exp_lse_kern_comp = torch.exp(lse_comp_val)
|
||||
exp_lse_kern_swa = torch.exp(lse_swa_val)
|
||||
|
||||
Reference in New Issue
Block a user