Return o_accum directly (un-normalized merge result)

This commit is contained in:
2026-05-27 06:42:58 +00:00
parent 6111db571c
commit b70ab2a6ee

View File

@@ -222,9 +222,5 @@ def _attention_single_head(
o_accum = (e_old * o_accum + e_new * seg_o) / e_sum
lse_accum = torch.log(e_sum)
-# o_accum is the LSE-merged un-normalized O. Normalize by the final LSE.
-# O_norm = O_unnorm / row_sum, where row_sum = exp(lse)
-row_sum = torch.exp(lse_accum).clamp(min=1e-30)
-o_norm = o_accum / row_sum
-output = o_norm.to(torch.bfloat16).unsqueeze(0) # (1, T, hd)
+output = o_accum.to(torch.bfloat16).unsqueeze(0) # (1, T, hd)
return output