From 4e64acbb648f2969f30b54415a573e84c9d6034b Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sun, 31 May 2026 21:57:47 +0000 Subject: [PATCH] fix MoE gate BF16/NVFP4 handling, add attention diagnostics --- single_shot_inference.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/single_shot_inference.py b/single_shot_inference.py index 0661f8e2..c85465a5 100644 --- a/single_shot_inference.py +++ b/single_shot_inference.py @@ -438,6 +438,10 @@ def forward_attention(x_normed, w, li, cfg, rope_cos, rope_sin, # 1. Q projection: q_a → q_a_norm → q_b → q_b_norm q_a = do_nvfp4_linear(x_normed, w, pfx, 'q_a_proj') + if q_a is None: + print(f" WARNING L{li}: q_a_proj not found, keys: {[k for k in w if 'q_a' in k and f'layers.{li}' in k][:5]}") + return torch.zeros(T, cfg["hidden_size"], dtype=torch.bfloat16, device=dev), None + if VERBOSE >= 2: print(f" L{li} q_a: |max|={q_a.abs().max().item():.4f} shape={q_a.shape}") q_norm_w = w.get(f"{pfx}.q_a_norm.weight") if q_norm_w is not None: q_a = rmsnorm(q_a, q_norm_w.to(dev, torch.float32)) q = do_nvfp4_linear(q_a, w, pfx, 'q_b_proj') @@ -447,6 +451,9 @@ def forward_attention(x_normed, w, li, cfg, rope_cos, rope_sin, # 2. KV projection (MQA, single KV head, hd dim) kv = do_nvfp4_linear(x_normed, w, pfx, 'kv_proj') + if kv is None: + print(f" WARNING L{li}: kv_proj not found, keys: {[k for k in w if 'kv_proj' in k and f'layers.{li}' in k][:5]}") + return torch.zeros(T, cfg["hidden_size"], dtype=torch.bfloat16, device=dev), q_a kv_norm_w = w.get(f"{pfx}.kv_norm.weight") if kv_norm_w is not None: kv = rmsnorm(kv, kv_norm_w.to(dev, torch.float32)) kv_3d = kv.reshape(T, 1, hd) @@ -551,14 +558,17 @@ def moe_forward(x, w, li, cfg, token_id, device): expert_ids = tid2eid[tid] expert_weights = torch.ones(top_k, dtype=torch.float32, device=x.device) / top_k else: + # Gate weight may be BF16 or NVFP4 gate_ww, gate_ws, gate_ws2, gate_isc = get_nvfp4_weight(w, pfx, 'gate') - if gate_ww is not None: + if gate_ww is not None and gate_ws is not None: logits = nvfp4_linear(x, gate_ww.to(device), gate_ws.to(device), gate_ws2.to(device) if gate_ws2 is not None else None, gate_isc.to(device) if gate_isc is not None else None) + elif f"{pfx}.gate.weight" in w: + gw = w[f"{pfx}.gate.weight"].bfloat16().to(device) + logits = F.linear(x, gw) else: - gw = w.get(f"{pfx}.gate.weight") - logits = F.linear(x, gw.bfloat16().to(device)) + raise ValueError(f"No gate weight for layer {li}") scores = torch.sqrt(F.softplus(logits.float()) + 1e-6) sel = scores.clone() if e_bias_key in w: