Fix MoE gate BF16/NVFP4 handling; add attention diagnostics
This commit is contained in:
@@ -438,6 +438,10 @@ def forward_attention(x_normed, w, li, cfg, rope_cos, rope_sin,
|
||||
|
||||
# 1. Q projection: q_a → q_a_norm → q_b → q_b_norm
|
||||
q_a = do_nvfp4_linear(x_normed, w, pfx, 'q_a_proj')
|
||||
if q_a is None:
|
||||
print(f" WARNING L{li}: q_a_proj not found, keys: {[k for k in w if 'q_a' in k and f'layers.{li}' in k][:5]}")
|
||||
return torch.zeros(T, cfg["hidden_size"], dtype=torch.bfloat16, device=dev), None
|
||||
if VERBOSE >= 2: print(f" L{li} q_a: |max|={q_a.abs().max().item():.4f} shape={q_a.shape}")
|
||||
q_norm_w = w.get(f"{pfx}.q_a_norm.weight")
|
||||
if q_norm_w is not None: q_a = rmsnorm(q_a, q_norm_w.to(dev, torch.float32))
|
||||
q = do_nvfp4_linear(q_a, w, pfx, 'q_b_proj')
|
||||
@@ -447,6 +451,9 @@ def forward_attention(x_normed, w, li, cfg, rope_cos, rope_sin,
|
||||
|
||||
# 2. KV projection (MQA, single KV head, hd dim)
|
||||
kv = do_nvfp4_linear(x_normed, w, pfx, 'kv_proj')
|
||||
if kv is None:
|
||||
print(f" WARNING L{li}: kv_proj not found, keys: {[k for k in w if 'kv_proj' in k and f'layers.{li}' in k][:5]}")
|
||||
return torch.zeros(T, cfg["hidden_size"], dtype=torch.bfloat16, device=dev), q_a
|
||||
kv_norm_w = w.get(f"{pfx}.kv_norm.weight")
|
||||
if kv_norm_w is not None: kv = rmsnorm(kv, kv_norm_w.to(dev, torch.float32))
|
||||
kv_3d = kv.reshape(T, 1, hd)
|
||||
@@ -551,14 +558,17 @@ def moe_forward(x, w, li, cfg, token_id, device):
|
||||
expert_ids = tid2eid[tid]
|
||||
expert_weights = torch.ones(top_k, dtype=torch.float32, device=x.device) / top_k
|
||||
else:
|
||||
# Gate weight may be BF16 or NVFP4
|
||||
gate_ww, gate_ws, gate_ws2, gate_isc = get_nvfp4_weight(w, pfx, 'gate')
|
||||
if gate_ww is not None:
|
||||
if gate_ww is not None and gate_ws is not None:
|
||||
logits = nvfp4_linear(x, gate_ww.to(device), gate_ws.to(device),
|
||||
gate_ws2.to(device) if gate_ws2 is not None else None,
|
||||
gate_isc.to(device) if gate_isc is not None else None)
|
||||
elif f"{pfx}.gate.weight" in w:
|
||||
gw = w[f"{pfx}.gate.weight"].bfloat16().to(device)
|
||||
logits = F.linear(x, gw)
|
||||
else:
|
||||
gw = w.get(f"{pfx}.gate.weight")
|
||||
logits = F.linear(x, gw.bfloat16().to(device))
|
||||
raise ValueError(f"No gate weight for layer {li}")
|
||||
scores = torch.sqrt(F.softplus(logits.float()) + 1e-6)
|
||||
sel = scores.clone()
|
||||
if e_bias_key in w:
|
||||
|
||||
Reference in New Issue
Block a user