fix: cast attn_out back to BF16 after sink correction
This commit is contained in:
@@ -393,8 +393,9 @@ def forward_layer(X_l, w, li, cfg, rope_cos, rope_sin,
|
||||
# Correction: attn_exp / (attn_exp + sink_exp)
|
||||
correction = attn_exp / (attn_exp + sink_exp.unsqueeze(0) + 1e-10) # (T, n_h)
|
||||
# Apply per-head correction
|
||||
attn_out = attn_out * correction.unsqueeze(-1) # (T, n_h, hd) * (T, n_h, 1)
|
||||
attn_out = (attn_out.float() * correction.unsqueeze(-1)).bfloat16() # (T, n_h, hd)
|
||||
|
||||
attn_out = attn_out.bfloat16()
|
||||
|
||||
# -- Inverse RoPE on attention output (paper §2.3.3) --
|
||||
attn_out = apply_inverse_rope(attn_out, positions_dev, rope_cos, rope_sin, hd, rd)
|
||||
|
||||
Reference in New Issue
Block a user