fix MoE gate BF16/NVFP4 handling, add attention diagnostics

This commit is contained in:
2026-05-31 21:57:47 +00:00
parent 0d2b5ceb93
commit 4e64acbb64

View File

@@ -438,6 +438,10 @@ def forward_attention(x_normed, w, li, cfg, rope_cos, rope_sin,
# 1. Q projection: q_a → q_a_norm → q_b → q_b_norm
q_a = do_nvfp4_linear(x_normed, w, pfx, 'q_a_proj')
if q_a is None:
print(f" WARNING L{li}: q_a_proj not found, keys: {[k for k in w if 'q_a' in k and f'layers.{li}' in k][:5]}")
return torch.zeros(T, cfg["hidden_size"], dtype=torch.bfloat16, device=dev), None
if VERBOSE >= 2: print(f" L{li} q_a: |max|={q_a.abs().max().item():.4f} shape={q_a.shape}")
q_norm_w = w.get(f"{pfx}.q_a_norm.weight")
if q_norm_w is not None: q_a = rmsnorm(q_a, q_norm_w.to(dev, torch.float32))
q = do_nvfp4_linear(q_a, w, pfx, 'q_b_proj')
@@ -447,6 +451,9 @@ def forward_attention(x_normed, w, li, cfg, rope_cos, rope_sin,
# 2. KV projection (MQA, single KV head, hd dim)
kv = do_nvfp4_linear(x_normed, w, pfx, 'kv_proj')
if kv is None:
print(f" WARNING L{li}: kv_proj not found, keys: {[k for k in w if 'kv_proj' in k and f'layers.{li}' in k][:5]}")
return torch.zeros(T, cfg["hidden_size"], dtype=torch.bfloat16, device=dev), q_a
kv_norm_w = w.get(f"{pfx}.kv_norm.weight")
if kv_norm_w is not None: kv = rmsnorm(kv, kv_norm_w.to(dev, torch.float32))
kv_3d = kv.reshape(T, 1, hd)
@@ -551,14 +558,17 @@ def moe_forward(x, w, li, cfg, token_id, device):
expert_ids = tid2eid[tid]
expert_weights = torch.ones(top_k, dtype=torch.float32, device=x.device) / top_k
else:
# Gate weight may be BF16 or NVFP4
gate_ww, gate_ws, gate_ws2, gate_isc = get_nvfp4_weight(w, pfx, 'gate')
if gate_ww is not None:
if gate_ww is not None and gate_ws is not None:
logits = nvfp4_linear(x, gate_ww.to(device), gate_ws.to(device),
gate_ws2.to(device) if gate_ws2 is not None else None,
gate_isc.to(device) if gate_isc is not None else None)
elif f"{pfx}.gate.weight" in w:
gw = w[f"{pfx}.gate.weight"].bfloat16().to(device)
logits = F.linear(x, gw)
else:
gw = w.get(f"{pfx}.gate.weight")
logits = F.linear(x, gw.bfloat16().to(device))
raise ValueError(f"No gate weight for layer {li}")
scores = torch.sqrt(F.softplus(logits.float()) + 1e-6)
sel = scores.clone()
if e_bias_key in w: