Use exact merge formula from working test_d1_kv_merge.py
This commit is contained in:
@@ -211,29 +211,16 @@ def _attention_single_head(
|
||||
if nt == 0:
|
||||
seg_lse[:, 0] = lse_tensor[:, 0, 0].float()
|
||||
|
||||
# Merge with accumulator using log-sum-exp
|
||||
# O_i from kernel is UN-normalized. We need normalized O for merge:
|
||||
# O_norm = O_unnorm / row_sum
|
||||
# Merge: O_merged = sum(exp(lse_i) * O_i_norm) / sum(exp(lse_i))
|
||||
#
|
||||
# We track O_accum in NORMALIZED form and merge incrementally.
|
||||
#
|
||||
# Special case: if this is the first segment (lse_accum=-inf),
|
||||
# just set O_accum = O_norm for this segment.
|
||||
if seg == 0:
|
||||
# First segment: normalize directly
|
||||
row_sum = torch.exp(seg_lse).clamp(min=1e-30)
|
||||
o_accum = seg_o / row_sum
|
||||
lse_accum = seg_lse
|
||||
else:
|
||||
# Merge: O_new = (e_old * O_old + e_new * O_new_norm) / (e_old + e_new)
|
||||
e_old = torch.exp(lse_accum)
|
||||
e_new = torch.exp(seg_lse)
|
||||
e_sum = e_old + e_new
|
||||
row_sum_new = e_new.clamp(min=1e-30)
|
||||
o_new_norm = seg_o / row_sum_new
|
||||
o_accum = (e_old * o_accum + e_new * o_new_norm) / e_sum
|
||||
lse_accum = torch.log(e_sum)
|
||||
# Merge with accumulator using log-sum-exp (same as test_d1_kv_merge.py)
|
||||
# Formula: O = (exp(lse_old) * O_old + exp(lse_new) * O_new) / (exp(lse_old) + exp(lse_new))
|
||||
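# O_new (seg_o) is the kernel output for this segment; O_old (o_accum) is the running accumulator from earlier merges, not a raw kernel output.
# NOTE(review): this weighted average is only normalized if seg_o is already normalized by its own exp(lse_new); the kernel output was previously described as un-normalized — confirm against the kernel.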
|
||||
# This produces the correct normalized result after merge.
|
||||
e_old = torch.exp(lse_accum)
|
||||
e_new = torch.exp(seg_lse)
|
||||
e_sum = e_old + e_new
|
||||
|
||||
o_accum = (e_old * o_accum + e_new * seg_o) / e_sum
|
||||
lse_accum = torch.log(e_sum)
|
||||
|
||||
# o_accum is already normalized (in the merge above)
|
||||
output = o_accum.to(torch.bfloat16).unsqueeze(0) # (1, T, hd)
|
||||
|
||||
Reference in New Issue
Block a user