diff --git a/single_shot_inference.py b/single_shot_inference.py index dec2b412..2b7faba7 100644 --- a/single_shot_inference.py +++ b/single_shot_inference.py @@ -540,12 +540,16 @@ def moe_forward(x, w, li, cfg, token_id, device): is_hash = li < 3 # ---- Routing ---- + expert_ids = None + expert_weights = None + if is_hash: tid2eid_key = f"model.layers.{li}.mlp.gate.tid2eid" if tid2eid_key in w: tid2eid = w[tid2eid_key] tid = token_id.item() if token_id.numel() == 1 else token_id[0].item() - expert_ids = tid2eid[tid] # (top_k,) + expert_ids = tid2eid[tid] # (top_k,) int64 + expert_weights = torch.ones(top_k, dtype=torch.float32, device=x.device) / top_k else: # Fallback: use dense routing even for hash layers is_hash = False @@ -562,8 +566,7 @@ def moe_forward(x, w, li, cfg, token_id, device): activated = activated + w[e_bias_key].float().unsqueeze(0) # Top-k scores, indices = activated.topk(top_k, dim=-1) # (T, top_k) - # Renormalize on UNBIASED activation - # Re-compute unbiased activation for weights + # Renormalize on UNBIASED activation (no e_bias in weights) unbiased = torch.sqrt(torch.nn.functional.softplus(logits.float()) + 1e-6) unbiased_scores = torch.gather(unbiased, -1, indices) expert_weights = unbiased_scores / unbiased_scores.sum(dim=-1, keepdim=True) @@ -572,7 +575,6 @@ def moe_forward(x, w, li, cfg, token_id, device): expert_ids = indices[0] expert_weights = expert_weights[0] else: - # Per-token routing (not yet needed for decode) raise NotImplementedError("Multi-token MoE routing") # ---- Run selected experts ----