Fix production KV merge: use normalized O for log-sum-exp merge

This commit is contained in:
2026-05-27 06:36:24 +00:00
parent 98c93c1cd8
commit 8321ccf9c1

View File

@@ -212,12 +212,29 @@ def _attention_single_head(
seg_lse[:, 0] = lse_tensor[:, 0, 0].float()
# Merge with accumulator using log-sum-exp
e_old = torch.exp(lse_accum)
e_new = torch.exp(seg_lse)
e_sum = e_old + e_new
o_accum = (e_old * o_accum + e_new * seg_o) / e_sum
lse_accum = torch.log(e_sum)
# O_i from kernel is UN-normalized. We need normalized O for merge:
# O_norm = O_unnorm / row_sum
# Merge: O_merged = sum(exp(lse_i) * O_i_norm) / sum(exp(lse_i))
#
# We track O_accum in NORMALIZED form and merge incrementally.
#
# Special case: if this is the first segment (lse_accum=-inf),
# just set O_accum = O_norm for this segment.
if seg == 0:
# First segment: normalize directly
row_sum = torch.exp(seg_lse).clamp(min=1e-30)
o_accum = seg_o / row_sum
lse_accum = seg_lse
else:
# Merge: O_new = (e_old * O_old + e_new * O_new_norm) / (e_old + e_new)
e_old = torch.exp(lse_accum)
e_new = torch.exp(seg_lse)
e_sum = e_old + e_new
row_sum_new = e_new.clamp(min=1e-30)
o_new_norm = seg_o / row_sum_new
o_accum = (e_old * o_accum + e_new * o_new_norm) / e_sum
lse_accum = torch.log(e_sum)
# Normalize
# o_accum is already normalized (in the merge above)
output = o_accum.to(torch.bfloat16).unsqueeze(0) # (1, T, hd)
return output