From d5816c8c2fa8dba84dc518c481a21bc6e5439acb Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Wed, 11 Mar 2026 15:10:26 +0000 Subject: [PATCH] Fix tied weights in weight mapping test for Transformers v5 (#36788) Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- tests/models/multimodal/test_mapping.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/tests/models/multimodal/test_mapping.py b/tests/models/multimodal/test_mapping.py index 1b7e530f3..8d4ccaf4e 100644 --- a/tests/models/multimodal/test_mapping.py +++ b/tests/models/multimodal/test_mapping.py @@ -31,12 +31,6 @@ def create_dummy_model(repo: str, model_arch: str) -> PreTrainedModel: config = AutoConfig.from_pretrained(repo) with torch.device("meta"): model = model_cls._from_config(config) - # TODO(hmellor): Remove this once Transformers has fixed tied weights on meta device - # https://github.com/huggingface/transformers/issues/43522 - if getattr(config.get_text_config(), "tie_word_embeddings", False) or getattr( - config, "tie_word_embeddings", False - ): - model.tie_weights() return model @@ -103,6 +97,15 @@ def test_hf_model_weights_mapper(model_arch: str): # Some checkpoints may have buffers, we ignore them for this test ref_weight_names -= buffer_names + # Some checkpoints include tied weights (e.g. lm_head tied to embed_tokens) in the + # safetensors file. In Transformers v5, named_parameters() will not include them + # after they are tied in the model, so the mapper will not be able to map them. + # We exclude them from the reference weight names for this test. + if isinstance(tied := getattr(hf_dummy_model, "_tied_weights_keys", None), dict): + mapped_tied_weights = mapper.apply((k, None) for k in tied) + tied_weight_names = set(map(lambda x: x[0], mapped_tied_weights)) + ref_weight_names -= tied_weight_names + weights_missing = ref_weight_names - weight_names weights_unmapped = weight_names - ref_weight_names assert not weights_missing and not weights_unmapped, (