Fix weight mapping test for Transformers v5 (#33162)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2026-01-27 12:30:14 +00:00
committed by GitHub
parent 76139d0801
commit 14385c80fc
3 changed files with 11 additions and 4 deletions

View File

@@ -249,7 +249,8 @@ class Base(
         # Layers before module list
         for name in pp_plan[:module_list_idx]:
             if self.pp_group.is_first_rank or (
-                self.text_config.tie_word_embeddings and self.pp_group.is_last_rank
+                getattr(self.text_config, "tie_word_embeddings", False)
+                and self.pp_group.is_last_rank
             ):
                 continue
             setattr(self.model, name, PPMissingLayer())

View File

@@ -38,7 +38,8 @@ class CausalMixin(VllmModelForTextGeneration):
         # Tell `Base.load_weights` to skip
         # `lm_head` if the model has tied word embeddings
-        if self.text_config.tie_word_embeddings:
+        tie_word_embeddings = getattr(self.text_config, "tie_word_embeddings", False)
+        if tie_word_embeddings:
             self.skip_prefixes.append("lm_head.")
         if self.pp_group.is_last_rank:
@@ -48,7 +49,7 @@ class CausalMixin(VllmModelForTextGeneration):
                 quant_config=self.quant_config,
                 prefix=maybe_prefix(prefix, "lm_head"),
             )
-            if self.text_config.tie_word_embeddings:
+            if tie_word_embeddings:
                 self.lm_head = self.lm_head.tie_weights(
                     self.model.get_input_embeddings()
                 )