diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index b5532d652..a6e7995bc 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -1506,6 +1506,24 @@ class SpecDecodeBaseProposer: del self.model.lm_head self.model.lm_head = target_language_model.lm_head + # MTP models call compute_logits via shared_head.head (a + # ParallelLMHead inside each MTP layer), not self.model.lm_head. + # If the checkpoint omits a copy of the lm_head weights at the + # MTP layer path, shared_head.head stays uninitialised and + # produces NaN logits. Always share it explicitly. + inner = getattr(self.model, "model", None) + layers = getattr(inner, "layers", None) if inner else None + if layers is not None: + items = layers.values() if isinstance(layers, nn.ModuleDict) else layers + for layer in items: + sh = getattr(layer, "shared_head", None) + if sh is not None and hasattr(sh, "head"): + del sh.head + sh.head = target_language_model.lm_head + logger.info( + "Shared target model lm_head with MTP shared_head.head." + ) + @torch.inference_mode() def dummy_run( self,