Debug: print LSE values for 2-segment merge
This commit is contained in:
@@ -195,5 +195,12 @@ def _attention_single_head(
|
||||
o_accum = (e_old * o_accum + e_new * seg_o) / e_sum
|
||||
lse_accum = torch.log(e_sum)
|
||||
|
||||
# Debug: check LSE values
|
||||
if seg == 0:
|
||||
print(f' seg 0: lse[0]={seg_lse[0,0].item():.4f}, o[0,0]={seg_o[0,0].item():.4f}')
|
||||
elif seg == 1:
|
||||
print(f' seg 1: lse[0]={seg_lse[0,0].item():.4f}, o[0,0]={seg_o[0,0].item():.4f}')
|
||||
print(f' merged: lse[0]={lse_accum[0,0].item():.4f}, o[0,0]={o_accum[0,0].item():.4f}')
|
||||
|
||||
output = o_accum.to(torch.bfloat16).unsqueeze(0)
|
||||
return output
|
||||
|
||||
Reference in New Issue
Block a user