Fix weight mapping test for Transformers v5 (#33162)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
@@ -249,7 +249,8 @@ class Base(
         # Layers before module list
         for name in pp_plan[:module_list_idx]:
             if self.pp_group.is_first_rank or (
-                self.text_config.tie_word_embeddings and self.pp_group.is_last_rank
+                getattr(self.text_config, "tie_word_embeddings", False)
+                and self.pp_group.is_last_rank
             ):
                 continue
             setattr(self.model, name, PPMissingLayer())
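The core of the fix is swapping direct attribute access for `getattr` with a `False` default. A minimal sketch of why this matters, assuming a Transformers v5-style text config that no longer carries the `tie_word_embeddings` field (`DummyTextConfig` below is a hypothetical stand-in, not a real class):

    # Minimal sketch (not vLLM code): why getattr with a default is needed.
    # DummyTextConfig is a hypothetical stand-in for a Transformers v5 text
    # config that no longer defines `tie_word_embeddings`.
    class DummyTextConfig:
        hidden_size = 64  # other fields exist, but not tie_word_embeddings

    config = DummyTextConfig()

    # Direct access raises AttributeError when the field is absent:
    try:
        _ = config.tie_word_embeddings
    except AttributeError:
        print("direct access fails on a v5-style config")

    # The defensive form used by this patch falls back to False instead:
    tie_word_embeddings = getattr(config, "tie_word_embeddings", False)
    print(tie_word_embeddings)  # False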
@@ -38,7 +38,8 @@ class CausalMixin(VllmModelForTextGeneration):
 
         # Tell `Base.load_weights` to skip
         # `lm_head` if the model has tied word embeddings
-        if self.text_config.tie_word_embeddings:
+        tie_word_embeddings = getattr(self.text_config, "tie_word_embeddings", False)
+        if tie_word_embeddings:
             self.skip_prefixes.append("lm_head.")
 
         if self.pp_group.is_last_rank:
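For context on the `skip_prefixes` list this hunk appends to, here is a hedged sketch of prefix-based filtering as it is typically applied when loading a checkpoint. The names mirror the diff, but the filtering loop itself is an assumption about how `Base.load_weights` consumes the list, not a copy of vLLM's implementation:

    # Illustrative only: how a skip_prefixes list is typically consumed.
    skip_prefixes = ["lm_head."]

    checkpoint = {
        "model.embed_tokens.weight": "tensor-0",
        "model.layers.0.mlp.gate_proj.weight": "tensor-1",
        "lm_head.weight": "tensor-2",
    }

    loaded = {
        name: weight
        for name, weight in checkpoint.items()
        if not any(name.startswith(prefix) for prefix in skip_prefixes)
    }
    assert "lm_head.weight" not in loaded  # tied weights never double-load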
@@ -48,7 +49,7 @@ class CausalMixin(VllmModelForTextGeneration):
                 quant_config=self.quant_config,
                 prefix=maybe_prefix(prefix, "lm_head"),
             )
-            if self.text_config.tie_word_embeddings:
+            if tie_word_embeddings:
                 self.lm_head = self.lm_head.tie_weights(
                     self.model.get_input_embeddings()
                 )
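The `tie_weights` call above points `lm_head` at the shared embedding matrix. A hedged sketch of the same idea in plain PyTorch (the `tie_weights` method in the hunk serves this role for vLLM's parallel layers; the vocab and hidden sizes below are arbitrary for illustration):

    # Hedged sketch of weight tying in plain PyTorch, not vLLM's own code.
    import torch.nn as nn

    embed_tokens = nn.Embedding(32000, 4096)
    lm_head = nn.Linear(4096, 32000, bias=False)

    # Point the output projection at the embedding matrix so both share
    # storage; loading lm_head.weight separately would be redundant
    # (hence the skip_prefixes entry above).
    lm_head.weight = embed_tokens.weight
    assert lm_head.weight.data_ptr() == embed_tokens.weight.data_ptr()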