Fix weight mapping test for Transformers v5 (#33162)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -30,7 +30,12 @@ def create_dummy_model(repo: str, model_arch: str) -> PreTrainedModel:
|
||||
model_cls: PreTrainedModel = getattr(transformers, model_arch)
|
||||
config = AutoConfig.from_pretrained(repo)
|
||||
with torch.device("meta"):
|
||||
return model_cls._from_config(config)
|
||||
model = model_cls._from_config(config)
|
||||
# TODO(hmellor): Remove this once Transformers has fixed tied weights on meta device
|
||||
# https://github.com/huggingface/transformers/issues/43522
|
||||
if getattr(config.get_text_config(), "tie_word_embeddings", False):
|
||||
model.tie_weights()
|
||||
return model
|
||||
|
||||
|
||||
def model_architectures_for_test() -> list[str]:
|
||||
|
||||
Reference in New Issue
Block a user