[Bugfix] Fix models and tests for transformers v5 (#33977)
Signed-off-by: raushan <raushan@huggingface.co> Signed-off-by: Raushan Turganbay <raushan.turganbay@alumni.nu.edu.kz> Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com> Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
51a7bda625
commit
85ee1d962b
@@ -33,7 +33,9 @@ def create_dummy_model(repo: str, model_arch: str) -> PreTrainedModel:
|
||||
model = model_cls._from_config(config)
|
||||
# TODO(hmellor): Remove this once Transformers has fixed tied weights on meta device
|
||||
# https://github.com/huggingface/transformers/issues/43522
|
||||
if getattr(config.get_text_config(), "tie_word_embeddings", False):
|
||||
if getattr(config.get_text_config(), "tie_word_embeddings", False) or getattr(
|
||||
config, "tie_word_embeddings", False
|
||||
):
|
||||
model.tie_weights()
|
||||
return model
|
||||
|
||||
|
||||
Reference in New Issue
Block a user