Fix tied weights in weight mapping test for Transformers v5 (#36788)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -31,12 +31,6 @@ def create_dummy_model(repo: str, model_arch: str) -> PreTrainedModel:
|
||||
config = AutoConfig.from_pretrained(repo)
|
||||
with torch.device("meta"):
|
||||
model = model_cls._from_config(config)
|
||||
# TODO(hmellor): Remove this once Transformers has fixed tied weights on meta device
|
||||
# https://github.com/huggingface/transformers/issues/43522
|
||||
if getattr(config.get_text_config(), "tie_word_embeddings", False) or getattr(
|
||||
config, "tie_word_embeddings", False
|
||||
):
|
||||
model.tie_weights()
|
||||
return model
|
||||
|
||||
|
||||
@@ -103,6 +97,15 @@ def test_hf_model_weights_mapper(model_arch: str):
|
||||
# Some checkpoints may have buffers, we ignore them for this test
|
||||
ref_weight_names -= buffer_names
|
||||
|
||||
# Some checkpoints include tied weights (e.g. lm_head tied to embed_tokens) in the
|
||||
# safetensors file. In Transformers v5, named_parameters() will not include them
|
||||
# after they are tied in the model, so the mapper will not be able to map them.
|
||||
# We exclude them from the reference weight names for this test.
|
||||
if isinstance(tied := getattr(hf_dummy_model, "_tied_weights_keys", None), dict):
|
||||
mapped_tied_weights = mapper.apply((k, None) for k in tied)
|
||||
tied_weight_names = set(map(lambda x: x[0], mapped_tied_weights))
|
||||
ref_weight_names -= tied_weight_names
|
||||
|
||||
weights_missing = ref_weight_names - weight_names
|
||||
weights_unmapped = weight_names - ref_weight_names
|
||||
assert not weights_missing and not weights_unmapped, (
|
||||
|
||||
Reference in New Issue
Block a user