[Bugfix] Fix MTP accuracy for GLM-5 (#34385)
Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
@@ -1506,6 +1506,24 @@ class SpecDecodeBaseProposer:
|
|||||||
del self.model.lm_head
|
del self.model.lm_head
|
||||||
self.model.lm_head = target_language_model.lm_head
|
self.model.lm_head = target_language_model.lm_head
|
||||||
|
|
||||||
|
# MTP models call compute_logits via shared_head.head (a
|
||||||
|
# ParallelLMHead inside each MTP layer), not self.model.lm_head.
|
||||||
|
# If the checkpoint omits a copy of the lm_head weights at the
|
||||||
|
# MTP layer path, shared_head.head stays uninitialised and
|
||||||
|
# produces NaN logits. Always share the target model's lm_head explicitly.
|
||||||
|
inner = getattr(self.model, "model", None)
|
||||||
|
layers = getattr(inner, "layers", None) if inner else None
|
||||||
|
if layers is not None:
|
||||||
|
items = layers.values() if isinstance(layers, nn.ModuleDict) else layers
|
||||||
|
for layer in items:
|
||||||
|
sh = getattr(layer, "shared_head", None)
|
||||||
|
if sh is not None and hasattr(sh, "head"):
|
||||||
|
del sh.head
|
||||||
|
sh.head = target_language_model.lm_head
|
||||||
|
logger.info(
|
||||||
|
"Shared target model lm_head with MTP shared_head.head."
|
||||||
|
)
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def dummy_run(
|
def dummy_run(
|
||||||
self,
|
self,
|
||||||
|
|||||||
Reference in New Issue
Block a user