Use exact merge formula from working test_d1_kv_merge.py

This commit is contained in:
2026-05-27 06:38:04 +00:00
parent 8321ccf9c1
commit ddc701af9b

View File

@@ -211,29 +211,16 @@ def _attention_single_head(
if nt == 0:
seg_lse[:, 0] = lse_tensor[:, 0, 0].float()
# Merge with accumulator using log-sum-exp
# O_i from kernel is UN-normalized. We need normalized O for merge:
# O_norm = O_unnorm / row_sum
# Merge: O_merged = sum(exp(lse_i) * O_i_norm) / sum(exp(lse_i))
#
# We track O_accum in NORMALIZED form and merge incrementally.
#
# Special case: if this is the first segment (lse_accum=-inf),
# just set O_accum = O_norm for this segment.
if seg == 0:
# First segment: normalize directly
row_sum = torch.exp(seg_lse).clamp(min=1e-30)
o_accum = seg_o / row_sum
lse_accum = seg_lse
else:
# Merge: O_new = (e_old * O_old + e_new * O_new_norm) / (e_old + e_new)
e_old = torch.exp(lse_accum)
e_new = torch.exp(seg_lse)
e_sum = e_old + e_new
row_sum_new = e_new.clamp(min=1e-30)
o_new_norm = seg_o / row_sum_new
o_accum = (e_old * o_accum + e_new * o_new_norm) / e_sum
lse_accum = torch.log(e_sum)
# Merge with accumulator using log-sum-exp (same as test_d1_kv_merge.py)
# Formula: O = (exp(lse_old) * O_old + exp(lse_new) * O_new) / (exp(lse_old) + exp(lse_new))
# Both O_old and O_new are un-normalized outputs from the kernel.
# This produces the correct normalized result after merge.
e_old = torch.exp(lse_accum)
e_new = torch.exp(seg_lse)
e_sum = e_old + e_new
o_accum = (e_old * o_accum + e_new * seg_o) / e_sum
lse_accum = torch.log(e_sum)
# o_accum is already normalized (in the merge above)
output = o_accum.to(torch.bfloat16).unsqueeze(0) # (1, T, hd)