Debug: print LSE values for 2-segment merge

This commit is contained in:
2026-05-27 07:04:39 +00:00
parent 8f8d14c300
commit 06e7f7ab48

View File

@@ -195,5 +195,12 @@ def _attention_single_head(
o_accum = (e_old * o_accum + e_new * seg_o) / e_sum
lse_accum = torch.log(e_sum)
# Debug: check LSE values
if seg == 0:
print(f' seg 0: lse[0]={seg_lse[0,0].item():.4f}, o[0,0]={seg_o[0,0].item():.4f}')
elif seg == 1:
print(f' seg 1: lse[0]={seg_lse[0,0].item():.4f}, o[0,0]={seg_o[0,0].item():.4f}')
print(f' merged: lse[0]={lse_accum[0,0].item():.4f}, o[0,0]={o_accum[0,0].item():.4f}')
output = o_accum.to(torch.bfloat16).unsqueeze(0)
return output