[Bugfix] Fix MTP accuracy for GLM-5 (#34385)

Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
Michael Goin
2026-02-11 22:08:19 -05:00
committed by GitHub
parent ff1f83b056
commit ec12d39d44

View File

@@ -1506,6 +1506,24 @@ class SpecDecodeBaseProposer:
del self.model.lm_head
self.model.lm_head = target_language_model.lm_head
# MTP models call compute_logits via shared_head.head (a
# ParallelLMHead inside each MTP layer), not self.model.lm_head.
# If the checkpoint omits a copy of the lm_head weights at the
# MTP layer path, shared_head.head stays uninitialised and
# produces NaN logits. Always share it explicitly.
inner = getattr(self.model, "model", None)
layers = getattr(inner, "layers", None) if inner else None
if layers is not None:
items = layers.values() if isinstance(layers, nn.ModuleDict) else layers
for layer in items:
sh = getattr(layer, "shared_head", None)
if sh is not None and hasattr(sh, "head"):
del sh.head
sh.head = target_language_model.lm_head
logger.info(
"Shared target model lm_head with MTP shared_head.head."
)
@torch.inference_mode()
def dummy_run(
self,