Fix single_shot: mHC replaces layernorm, no hidden-level norm in DSV4
This commit is contained in:
@@ -196,23 +196,11 @@ def forward_layer(x, w, li, cfg, rope_cos, rope_sin):
|
||||
T = x.shape[0]
|
||||
|
||||
# ---- RMSNorm (attention) ----
|
||||
norm_w = w.get(f"model.layers.{li}.self_attn.kv_norm.weight")
|
||||
# Actually check for the right norm key
|
||||
# The norm might be "input_layernorm" or "attn_norm"
|
||||
for key_candidate in [f"model.layers.{li}.self_attn.kv_norm.weight",
|
||||
f"model.layers.{li}.input_layernorm.weight",
|
||||
f"model.layers.{li}.self_attn.norm.weight"]:
|
||||
norm_w = w.get(key_candidate)
|
||||
if norm_w is not None:
|
||||
break
|
||||
|
||||
if norm_w is not None:
|
||||
x_f = x.float()
|
||||
rms = x_f.pow(2).mean(-1, keepdim=True).add(1e-6).rsqrt()
|
||||
x_norm = (x_f * rms * norm_w.cuda().float()).bfloat16()
|
||||
else:
|
||||
x_norm = x
|
||||
print(f" L{li}: no norm weight found, skipping norm")
|
||||
# DSV4 uses mHC prenorm, not standard layernorm.
|
||||
# For baseline, use q_a_norm on the Q path and kv_norm on the KV path.
|
||||
# No hidden-level norm (mHC handles it).
|
||||
q_norm_w = w.get(f"{pre}.q_a_norm.weight") # (q_lora,) BF16
|
||||
kv_norm_w = w.get(f"{pre}.kv_norm.weight") # (hd,) BF16
|
||||
|
||||
# ---- Q projection: q_a (down) → q_b (up) ----
|
||||
qa_w = w[f"{pre}.q_a_proj.weight"]
|
||||
@@ -222,14 +210,16 @@ def forward_layer(x, w, li, cfg, rope_cos, rope_sin):
|
||||
qb_s = w[f"{pre}.q_b_proj.weight_scale"]
|
||||
qb_s2 = w[f"{pre}.q_b_proj.weight_scale_2"]
|
||||
|
||||
c_Q = nvfp4_linear(x_norm, qa_w, qa_s, qa_s2) # (1, q_lora)
|
||||
# For baseline: skip per-projection norms (mHC handles it)
|
||||
# Just project raw hidden
|
||||
c_Q = nvfp4_linear(x, qa_w, qa_s, qa_s2) # (1, q_lora)
|
||||
q = nvfp4_linear(c_Q, qb_w, qb_s, qb_s2) # (1, n_h * hd)
|
||||
|
||||
# ---- KV projection ----
|
||||
kv_w = w[f"{pre}.kv_proj.weight"]
|
||||
kv_s = w[f"{pre}.kv_proj.weight_scale"]
|
||||
kv_s2 = w[f"{pre}.kv_proj.weight_scale_2"]
|
||||
kv = nvfp4_linear(x_norm, kv_w, kv_s, kv_s2) # (1, kv_dim)
|
||||
kv = nvfp4_linear(x, kv_w, kv_s, kv_s2) # (1, kv_dim)
|
||||
|
||||
# ---- Reshape for attention ----
|
||||
q_heads = q.reshape(T, n_h, hd).permute(1, 0, 2) # (n_h, T, hd)
|
||||
@@ -271,21 +261,8 @@ def forward_layer(x, w, li, cfg, rope_cos, rope_sin):
|
||||
x = x + attn_proj
|
||||
|
||||
# ---- FFN (shared expert only for baseline) ----
|
||||
# RMSNorm (FFN)
|
||||
ffn_norm_w = None
|
||||
for key_candidate in [f"model.layers.{li}.post_attention_layernorm.weight",
|
||||
f"model.layers.{li}.self_attn.ffn_norm.weight",
|
||||
f"model.layers.{li}.norm.weight"]:
|
||||
ffn_norm_w = w.get(key_candidate)
|
||||
if ffn_norm_w is not None:
|
||||
break
|
||||
|
||||
if ffn_norm_w is not None:
|
||||
x_f = x.float()
|
||||
rms = x_f.pow(2).mean(-1, keepdim=True).add(1e-6).rsqrt()
|
||||
x_ffn_in = (x_f * rms * ffn_norm_w.cuda().float()).bfloat16()
|
||||
else:
|
||||
x_ffn_in = x
|
||||
# No separate FFN norm in DSV4 — mHC handles it
|
||||
# For baseline, just apply shared expert to the residual x directly
|
||||
|
||||
# Shared expert: gate_proj + up_proj → SiLU(gate) * up → down_proj
|
||||
se_pre = f"model.layers.{li}.mlp.shared_experts"
|
||||
@@ -294,10 +271,10 @@ def forward_layer(x, w, li, cfg, rope_cos, rope_sin):
|
||||
se_down_w = w.get(f"{se_pre}.down_proj.weight")
|
||||
|
||||
if se_gate_w is not None and se_up_w is not None and se_down_w is not None:
|
||||
gate = nvfp4_linear(x_ffn_in, se_gate_w,
|
||||
gate = nvfp4_linear(x, se_gate_w,
|
||||
w[f"{se_pre}.gate_proj.weight_scale"],
|
||||
w[f"{se_pre}.gate_proj.weight_scale_2"])
|
||||
up = nvfp4_linear(x_ffn_in, se_up_w,
|
||||
up = nvfp4_linear(x, se_up_w,
|
||||
w[f"{se_pre}.up_proj.weight_scale"],
|
||||
w[f"{se_pre}.up_proj.weight_scale_2"])
|
||||
ffn_out = nvfp4_linear(
|
||||
|
||||
Reference in New Issue
Block a user