Return o_accum directly (un-normalized merge result), skipping the final division by exp(lse_accum)
This commit is contained in:
@@ -222,9 +222,5 @@ def _attention_single_head(
|
||||
o_accum = (e_old * o_accum + e_new * seg_o) / e_sum
|
||||
lse_accum = torch.log(e_sum)
|
||||
|
||||
# o_accum is the LSE-merged un-normalized O. Normalize by the final LSE.
|
||||
# O_norm = O_unnorm / row_sum, where row_sum = exp(lse)
|
||||
row_sum = torch.exp(lse_accum).clamp(min=1e-30)
|
||||
o_norm = o_accum / row_sum
|
||||
output = o_norm.to(torch.bfloat16).unsqueeze(0) # (1, T, hd)
|
||||
output = o_accum.to(torch.bfloat16).unsqueeze(0) # (1, T, hd)
|
||||
return output
|
||||
|
||||
Reference in New Issue
Block a user