diff --git a/single_shot_inference.py b/single_shot_inference.py index bb6d4aac..b4949944 100644 --- a/single_shot_inference.py +++ b/single_shot_inference.py @@ -429,21 +429,33 @@ def forward_layer(X_l, w, li, cfg, rope_cos, rope_sin, # -- RMSNorm (pre-norm before attention) -- x_normed = attn_norm.forward(x_in) # (T, H) BF16 - # -- Q projection: q_a (low-rank down) → q_b (low-rank up) -- + # -- Q projection: q_a (low-rank down) → q_a_norm → q_b (low-rank up) -- c_Q = nvfp4_linear(x_normed, w[f"{pre}.q_a_proj.weight"], w[f"{pre}.q_a_proj.weight_scale"], w[f"{pre}.q_a_proj.weight_scale_2"]) # (T, dc) + # Q norm (RMSNorm after q_a, before q_b) + q_norm_w = w.get(f"{pre}.q_a_norm.weight") + if q_norm_w is not None: + c_Q_f = c_Q.float() + c_Q_rms = c_Q_f.pow(2).mean(-1, keepdim=True).add(1e-6).rsqrt() + c_Q = (c_Q_f * c_Q_rms * q_norm_w.float()).bfloat16() q = nvfp4_linear(c_Q, w[f"{pre}.q_b_proj.weight"], w[f"{pre}.q_b_proj.weight_scale"], w[f"{pre}.q_b_proj.weight_scale_2"]) # (T, n_h * hd) - # -- KV projection (MQA: 1 KV head) -- + # -- KV projection (MQA: 1 KV head) + KV norm -- kv = nvfp4_linear(x_normed, w[f"{pre}.kv_proj.weight"], w[f"{pre}.kv_proj.weight_scale"], - w[f"{pre}.kv_proj.weight_scale_2"]) # (T, hd) — 1 KV head, no split + w[f"{pre}.kv_proj.weight_scale_2"]) # (T, hd) + # KV norm (RMSNorm after kv_proj) + kv_norm_w = w.get(f"{pre}.kv_norm.weight") + if kv_norm_w is not None: + kv_f = kv.float() + kv_rms = kv_f.pow(2).mean(-1, keepdim=True).add(1e-6).rsqrt() + kv = (kv_f * kv_rms * kv_norm_w.float()).bfloat16() # -- Reshape for attention -- q_heads = q.reshape(T, n_h, hd) # (T, n_h, hd) @@ -469,9 +481,8 @@ def forward_layer(X_l, w, li, cfg, rope_cos, rope_sin, # -- FMHA: (n_h, T, hd) × (1, seq_len, hd) → (n_h, T, hd) -- from dsv4.kernels.attention.production import dsv4_attention q_input = q_heads.permute(1, 0, 2) # (n_h, T, hd) - k_input = k_full.permute(1, 0, 2) # (1, seq_len, hd) — already RoPE'd - v_input = v_full.permute(1, 0, 2) # (1, seq_len, hd) — K=V, RoPE'd - attn_out = dsv4_attention(q_input, k_input, v_input) # (n_h, T, hd) + # k_full, v_full are (1, seq_len, hd) — already in (n_kv, N, hd) format + attn_out = dsv4_attention(q_input, k_full, v_full) # (n_h, T, hd) attn_out = attn_out.permute(1, 0, 2) # (T, n_h, hd) # -- Inverse RoPE on attention output (paper §2.3.3) --