From 8321ccf9c1d7ee3a1222a2cc3c35b3ff9dc175a6 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Wed, 27 May 2026 06:36:24 +0000 Subject: [PATCH] Fix production KV merge: use normalized O for log-sum-exp merge --- dsv4/kernels/attention/production.py | 29 ++++++++++++++++++++++------ 1 file changed, 23 insertions(+), 6 deletions(-) diff --git a/dsv4/kernels/attention/production.py b/dsv4/kernels/attention/production.py index 438a8ae5..34279be3 100644 --- a/dsv4/kernels/attention/production.py +++ b/dsv4/kernels/attention/production.py @@ -212,12 +212,29 @@ def _attention_single_head( seg_lse[:, 0] = lse_tensor[:, 0, 0].float() # Merge with accumulator using log-sum-exp - e_old = torch.exp(lse_accum) - e_new = torch.exp(seg_lse) - e_sum = e_old + e_new - o_accum = (e_old * o_accum + e_new * seg_o) / e_sum - lse_accum = torch.log(e_sum) + # O_i from kernel is UN-normalized. We need normalized O for merge: + # O_norm = O_unnorm / row_sum + # Merge: O_merged = sum(exp(lse_i) * O_i_norm) / sum(exp(lse_i)) + # + # We track O_accum in NORMALIZED form and merge incrementally. + # + # Special case: if this is the first segment (lse_accum=-inf), + # just set O_accum = O_norm for this segment. + if seg == 0: + # First segment: normalize directly + row_sum = torch.exp(seg_lse).clamp(min=1e-30) + o_accum = seg_o / row_sum + lse_accum = seg_lse + else: + # Merge: O_new = (e_old * O_old + e_new * O_new_norm) / (e_old + e_new) + e_old = torch.exp(lse_accum) + e_new = torch.exp(seg_lse) + e_sum = e_old + e_new + row_sum_new = e_new.clamp(min=1e-30) + o_new_norm = seg_o / row_sum_new + o_accum = (e_old * o_accum + e_new * o_new_norm) / e_sum + lse_accum = torch.log(e_sum) - # Normalize + # o_accum is already normalized (in the merge above) output = o_accum.to(torch.bfloat16).unsqueeze(0) # (1, T, hd) return output