diff --git a/single_shot_inference.py b/single_shot_inference.py index e3503ea2..57a30d0b 100644 --- a/single_shot_inference.py +++ b/single_shot_inference.py @@ -393,8 +393,9 @@ def forward_layer(X_l, w, li, cfg, rope_cos, rope_sin, # Correction: attn_exp / (attn_exp + sink_exp) correction = attn_exp / (attn_exp + sink_exp.unsqueeze(0) + 1e-10) # (T, n_h) # Apply per-head correction - attn_out = attn_out * correction.unsqueeze(-1) # (T, n_h, hd) * (T, n_h, 1) + attn_out = (attn_out.float() * correction.unsqueeze(-1)).bfloat16() # (T, n_h, hd) + attn_out = attn_out.bfloat16() # -- Inverse RoPE on attention output (paper ยง2.3.3) -- attn_out = apply_inverse_rope(attn_out, positions_dev, rope_cos, rope_sin, hd, rd)