From ddc701af9b451e84df6a9f6b4d64caf4ce92bd52 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Wed, 27 May 2026 06:38:04 +0000 Subject: [PATCH] Use exact merge formula from working test_d1_kv_merge.py --- dsv4/kernels/attention/production.py | 33 +++++++++------------------- 1 file changed, 10 insertions(+), 23 deletions(-) diff --git a/dsv4/kernels/attention/production.py b/dsv4/kernels/attention/production.py index 34279be3..f33c7e03 100644 --- a/dsv4/kernels/attention/production.py +++ b/dsv4/kernels/attention/production.py @@ -211,29 +211,16 @@ def _attention_single_head( if nt == 0: seg_lse[:, 0] = lse_tensor[:, 0, 0].float() - # Merge with accumulator using log-sum-exp - # O_i from kernel is UN-normalized. We need normalized O for merge: - # O_norm = O_unnorm / row_sum - # Merge: O_merged = sum(exp(lse_i) * O_i_norm) / sum(exp(lse_i)) - # - # We track O_accum in NORMALIZED form and merge incrementally. - # - # Special case: if this is the first segment (lse_accum=-inf), - # just set O_accum = O_norm for this segment. - if seg == 0: - # First segment: normalize directly - row_sum = torch.exp(seg_lse).clamp(min=1e-30) - o_accum = seg_o / row_sum - lse_accum = seg_lse - else: - # Merge: O_new = (e_old * O_old + e_new * O_new_norm) / (e_old + e_new) - e_old = torch.exp(lse_accum) - e_new = torch.exp(seg_lse) - e_sum = e_old + e_new - row_sum_new = e_new.clamp(min=1e-30) - o_new_norm = seg_o / row_sum_new - o_accum = (e_old * o_accum + e_new * o_new_norm) / e_sum - lse_accum = torch.log(e_sum) + # Merge with accumulator using log-sum-exp (same as test_d1_kv_merge.py) + # Formula: O = (exp(lse_old) * O_old + exp(lse_new) * O_new) / (exp(lse_old) + exp(lse_new)) + # O_old is the running accumulator (o_accum); O_new is this segment's output (seg_o). + # NOTE(review): this exp(lse) weighting is the merge for normalized O_i (the replaced code divided seg_o by exp(seg_lse) first) - confirm seg_o is already normalized.
+ e_old = torch.exp(lse_accum) + e_new = torch.exp(seg_lse) + e_sum = e_old + e_new + + o_accum = (e_old * o_accum + e_new * seg_o) / e_sum + lse_accum = torch.log(e_sum) # o_accum is already normalized (in the merge above) output = o_accum.to(torch.bfloat16).unsqueeze(0) # (1, T, hd)