fix: expert_weights/ids scoping in hash routing path

This commit is contained in:
2026-05-31 02:50:32 +00:00
parent d772885d7e
commit 61160ace13

View File

@@ -540,12 +540,16 @@ def moe_forward(x, w, li, cfg, token_id, device):
is_hash = li < 3
# ---- Routing ----
expert_ids = None
expert_weights = None
if is_hash:
tid2eid_key = f"model.layers.{li}.mlp.gate.tid2eid"
if tid2eid_key in w:
tid2eid = w[tid2eid_key]
tid = token_id.item() if token_id.numel() == 1 else token_id[0].item()
expert_ids = tid2eid[tid] # (top_k,)
expert_ids = tid2eid[tid] # (top_k,) int64
expert_weights = torch.ones(top_k, dtype=torch.float32, device=x.device) / top_k
else:
# Fallback: use dense routing even for hash layers
is_hash = False
@@ -562,8 +566,7 @@ def moe_forward(x, w, li, cfg, token_id, device):
activated = activated + w[e_bias_key].float().unsqueeze(0)
# Top-k
scores, indices = activated.topk(top_k, dim=-1) # (T, top_k)
# Renormalize on UNBIASED activation
# Re-compute unbiased activation for weights
# Renormalize on UNBIASED activation (no e_bias in weights)
unbiased = torch.sqrt(torch.nn.functional.softplus(logits.float()) + 1e-6)
unbiased_scores = torch.gather(unbiased, -1, indices)
expert_weights = unbiased_scores / unbiased_scores.sum(dim=-1, keepdim=True)
@@ -572,7 +575,6 @@ def moe_forward(x, w, li, cfg, token_id, device):
expert_ids = indices[0]
expert_weights = expert_weights[0]
else:
# Per-token routing (not yet needed for decode)
raise NotImplementedError("Multi-token MoE routing")
# ---- Run selected experts ----