Fix weight mapping test for Transformers v5 (#33162)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2026-01-27 12:30:14 +00:00
committed by GitHub
parent 76139d0801
commit 14385c80fc
3 changed files with 11 additions and 4 deletions

View File

def create_dummy_model(repo: str, model_arch: str) -> PreTrainedModel:
    """Build a weight-less instance of ``model_arch`` from ``repo``'s config.

    The model class is looked up by name on the ``transformers`` module and
    instantiated from the hub config on the ``meta`` device, so no real
    weights are allocated or downloaded — suitable for structure/weight-name
    mapping tests only.

    Args:
        repo: Hugging Face hub repo id whose config to load.
        model_arch: Name of a model class exported by ``transformers``
            (e.g. ``"LlamaForCausalLM"``).

    Returns:
        The meta-device model instance.
    """
    model_cls: PreTrainedModel = getattr(transformers, model_arch)
    config = AutoConfig.from_pretrained(repo)
    with torch.device("meta"):
        model = model_cls._from_config(config)
        # TODO(hmellor): Remove this once Transformers has fixed tied weights on meta device
        # https://github.com/huggingface/transformers/issues/43522
        # NOTE(review): tying is only re-applied when the text config opts in;
        # `getattr` default False keeps configs without the attribute untouched.
        if getattr(config.get_text_config(), "tie_word_embeddings", False):
            model.tie_weights()
        return model
def model_architectures_for_test() -> list[str]: