From 13bae9dd55e2d7c92475d65b96c77d77f20d39fc Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sat, 30 May 2026 22:42:17 +0000 Subject: [PATCH] Fix single_shot: mHC replaces layernorm, no hidden-level norm in DSV4 --- single_shot_inference.py | 49 +++++++++++----------------------------- 1 file changed, 13 insertions(+), 36 deletions(-) diff --git a/single_shot_inference.py b/single_shot_inference.py index 064ae5c9..8e79b3d8 100644 --- a/single_shot_inference.py +++ b/single_shot_inference.py @@ -196,23 +196,11 @@ def forward_layer(x, w, li, cfg, rope_cos, rope_sin): T = x.shape[0] # ---- RMSNorm (attention) ---- - norm_w = w.get(f"model.layers.{li}.self_attn.kv_norm.weight") - # Actually check for the right norm key - # The norm might be "input_layernorm" or "attn_norm" - for key_candidate in [f"model.layers.{li}.self_attn.kv_norm.weight", - f"model.layers.{li}.input_layernorm.weight", - f"model.layers.{li}.self_attn.norm.weight"]: - norm_w = w.get(key_candidate) - if norm_w is not None: - break - - if norm_w is not None: - x_f = x.float() - rms = x_f.pow(2).mean(-1, keepdim=True).add(1e-6).rsqrt() - x_norm = (x_f * rms * norm_w.cuda().float()).bfloat16() - else: - x_norm = x - print(f" L{li}: no norm weight found, skipping norm") + # DSV4 uses mHC prenorm, not standard layernorm. + # The q_a_norm / kv_norm weights are fetched here for the Q and KV paths; NOTE(review): the baseline projections below take the raw hidden state and c_Q feeds q_b_proj directly, so confirm whether these weights are actually applied. + # No hidden-level norm (mHC handles it). 
+ q_norm_w = w.get(f"{pre}.q_a_norm.weight") # (q_lora,) BF16 + kv_norm_w = w.get(f"{pre}.kv_norm.weight") # (hd,) BF16 # ---- Q projection: q_a (down) → q_b (up) ---- qa_w = w[f"{pre}.q_a_proj.weight"] @@ -222,14 +210,16 @@ def forward_layer(x, w, li, cfg, rope_cos, rope_sin): qb_s = w[f"{pre}.q_b_proj.weight_scale"] qb_s2 = w[f"{pre}.q_b_proj.weight_scale_2"] - c_Q = nvfp4_linear(x_norm, qa_w, qa_s, qa_s2) # (1, q_lora) + # For baseline: skip per-projection norms (mHC handles it) + # Just project raw hidden + c_Q = nvfp4_linear(x, qa_w, qa_s, qa_s2) # (1, q_lora) q = nvfp4_linear(c_Q, qb_w, qb_s, qb_s2) # (1, n_h * hd) # ---- KV projection ---- kv_w = w[f"{pre}.kv_proj.weight"] kv_s = w[f"{pre}.kv_proj.weight_scale"] kv_s2 = w[f"{pre}.kv_proj.weight_scale_2"] - kv = nvfp4_linear(x_norm, kv_w, kv_s, kv_s2) # (1, kv_dim) + kv = nvfp4_linear(x, kv_w, kv_s, kv_s2) # (1, kv_dim) # ---- Reshape for attention ---- q_heads = q.reshape(T, n_h, hd).permute(1, 0, 2) # (n_h, T, hd) @@ -271,21 +261,8 @@ def forward_layer(x, w, li, cfg, rope_cos, rope_sin): x = x + attn_proj # ---- FFN (shared expert only for baseline) ---- - # RMSNorm (FFN) - ffn_norm_w = None - for key_candidate in [f"model.layers.{li}.post_attention_layernorm.weight", - f"model.layers.{li}.self_attn.ffn_norm.weight", - f"model.layers.{li}.norm.weight"]: - ffn_norm_w = w.get(key_candidate) - if ffn_norm_w is not None: - break - - if ffn_norm_w is not None: - x_f = x.float() - rms = x_f.pow(2).mean(-1, keepdim=True).add(1e-6).rsqrt() - x_ffn_in = (x_f * rms * ffn_norm_w.cuda().float()).bfloat16() - else: - x_ffn_in = x + # No separate FFN norm in DSV4 — mHC handles it + # For baseline, just apply shared expert to the residual x directly # Shared expert: gate_proj + up_proj → SiLU(gate) * up → down_proj se_pre = f"model.layers.{li}.mlp.shared_experts" @@ -294,10 +271,10 @@ def forward_layer(x, w, li, cfg, rope_cos, rope_sin): se_down_w = w.get(f"{se_pre}.down_proj.weight") if se_gate_w is not None and se_up_w 
is not None and se_down_w is not None: - gate = nvfp4_linear(x_ffn_in, se_gate_w, + gate = nvfp4_linear(x, se_gate_w, w[f"{se_pre}.gate_proj.weight_scale"], w[f"{se_pre}.gate_proj.weight_scale_2"]) - up = nvfp4_linear(x_ffn_in, se_up_w, + up = nvfp4_linear(x, se_up_w, w[f"{se_pre}.up_proj.weight_scale"], w[f"{se_pre}.up_proj.weight_scale_2"]) ffn_out = nvfp4_linear(