From 59c75ca4e9ab24ce2f3545d45cda7c56b26a258a Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sun, 31 May 2026 06:07:06 +0000 Subject: [PATCH] fix: cast attn_out back to BF16 after sink correction --- single_shot_inference.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/single_shot_inference.py b/single_shot_inference.py index e3503ea2..57a30d0b 100644 --- a/single_shot_inference.py +++ b/single_shot_inference.py @@ -393,8 +393,9 @@ def forward_layer(X_l, w, li, cfg, rope_cos, rope_sin, # Correction: attn_exp / (attn_exp + sink_exp) correction = attn_exp / (attn_exp + sink_exp.unsqueeze(0) + 1e-10) # (T, n_h) # Apply per-head correction - attn_out = attn_out * correction.unsqueeze(-1) # (T, n_h, hd) * (T, n_h, 1) + attn_out = (attn_out.float() * correction.unsqueeze(-1)).bfloat16() # (T, n_h, hd) + attn_out = attn_out.bfloat16() # -- Inverse RoPE on attention output (paper §2.3.3) -- attn_out = apply_inverse_rope(attn_out, positions_dev, rope_cos, rope_sin, hd, rd)