[Bugfix] Fix MTP accuracy for GLM-5 (#34385)

Signed-off-by: mgoin <mgoin64@gmail.com>
(cherry picked from commit ec12d39d44)
This commit is contained in:
Michael Goin
2026-02-11 22:08:19 -05:00
committed by khluu
parent 946b2f106c
commit 7a06e5b05b

View File

@@ -1503,6 +1503,24 @@ class SpecDecodeBaseProposer:
del self.model.lm_head
self.model.lm_head = target_language_model.lm_head
# MTP models call compute_logits via shared_head.head (a
# ParallelLMHead inside each MTP layer), not self.model.lm_head.
# If the checkpoint omits a copy of the lm_head weights at the
# MTP layer path, shared_head.head stays uninitialised and
# produces NaN logits. Always share it explicitly.
inner = getattr(self.model, "model", None)
layers = getattr(inner, "layers", None) if inner else None
if layers is not None:
items = layers.values() if isinstance(layers, nn.ModuleDict) else layers
for layer in items:
sh = getattr(layer, "shared_head", None)
if sh is not None and hasattr(sh, "head"):
del sh.head
sh.head = target_language_model.lm_head
logger.info(
"Shared target model lm_head with MTP shared_head.head."
)
@torch.inference_mode()
def dummy_run(
self,