fix: cast attn_out back to BF16 after sink correction

This commit is contained in:
2026-05-31 06:07:06 +00:00
parent e5245ea34e
commit 59c75ca4e9

View File

@@ -393,8 +393,9 @@ def forward_layer(X_l, w, li, cfg, rope_cos, rope_sin,
# Correction: attn_exp / (attn_exp + sink_exp)
correction = attn_exp / (attn_exp + sink_exp.unsqueeze(0) + 1e-10) # (T, n_h)
# Apply per-head correction
attn_out = attn_out * correction.unsqueeze(-1) # (T, n_h, hd) * (T, n_h, 1)
attn_out = (attn_out.float() * correction.unsqueeze(-1)).bfloat16() # (T, n_h, hd)
attn_out = attn_out.bfloat16()
# -- Inverse RoPE on attention output (paper §2.3.3) --
attn_out = apply_inverse_rope(attn_out, positions_dev, rope_cos, rope_sin, hd, rd)