Use Transformers v5 WeightRenaming for Transformers modeling backend (#31545)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Harry Mellor
2026-03-13 20:49:08 +00:00
committed by GitHub
parent d0b402974f
commit 0005d2a3c9
7 changed files with 162 additions and 89 deletions

View File

@@ -17,6 +17,7 @@
"""Transformers modeling backend base class."""
from collections.abc import Iterable
+from itertools import chain
from typing import TYPE_CHECKING
import regex as re
@@ -107,27 +108,6 @@ class Base(
    SupportsEagle3,
):
    embedding_modules = ["embed_tokens"]  # TODO transformers will have a util to get it
-    hf_to_vllm_mapper = WeightsMapper(
-        orig_to_new_prefix={
-            # Add `model.` prefix for base model checkpoints,
-            # handling the case where it is already present
-            "": "model.",
-            "model.model.": "model.",
-            # Heads will be adjacent to `model` (pooling included because of adapters)
-            "model.lm_head.": "lm_head.",
-            "model.score.": "classifier.",
-            "model.classifier.": "classifier.",
-        }
-    )
-
-    def __init_subclass__(cls, *args, **kwargs):
-        """Merge hf_to_vllm_mapper in MRO from most specific to least specific."""
-        super().__init_subclass__(*args, **kwargs)
-        hf_to_vllm_mapper = WeightsMapper()
-        for base in cls.__mro__:
-            if base_hf_to_vllm_mapper := getattr(base, "hf_to_vllm_mapper", None):
-                hf_to_vllm_mapper |= base_hf_to_vllm_mapper
-        cls.hf_to_vllm_mapper = hf_to_vllm_mapper

    def __init__(self, *, vllm_config: "VllmConfig", prefix: str = ""):
        super().__init__()
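
Note: the class-level mapper removed above depended on its prefix rules firing in insertion order. A minimal standalone sketch of that behaviour (the apply_prefixes helper is illustrative only, not vLLM's actual WeightsMapper):

    # Hypothetical helper: applies each prefix rule in insertion order,
    # mirroring how the removed orig_to_new_prefix table behaved.
    def apply_prefixes(name: str, rules: dict[str, str]) -> str:
        for old, new in rules.items():  # dicts preserve insertion order
            if name.startswith(old):
                name = new + name[len(old):]
        return name

    rules = {
        "": "model.",              # always prepend the base model prefix...
        "model.model.": "model.",  # ...then collapse it if it was already there
        "model.lm_head.": "lm_head.",
    }

    assert apply_prefixes("embed_tokens.weight", rules) == "model.embed_tokens.weight"
    assert apply_prefixes("model.embed_tokens.weight", rules) == "model.embed_tokens.weight"
    assert apply_prefixes("lm_head.weight", rules) == "lm_head.weight"
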
@@ -174,8 +154,8 @@ class Base(
if "gptq" in quant_method_name:
self.ignore_unexpected_suffixes.append(".bias")
# Set correct attn and init on "meta" to delay allocating GPU tensors
self.text_config._attn_implementation = "vllm"
# Patch config and init on "meta" to delay allocating GPU tensors
self._patch_config()
with init_on_device_without_buffers("meta"):
self.model: PreTrainedModel = AutoModel.from_config(
self.config,
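
Note: the "meta" init above is what lets the backend build the Transformers module tree without allocating GPU memory. A minimal sketch of the same pattern using plain PyTorch (vLLM's init_on_device_without_buffers helper additionally, as its name suggests, leaves buffers out of the meta placement):

    import torch
    from torch import nn

    # Under torch.device("meta"), parameters get shape/dtype but no storage,
    # so constructing even a large layer allocates nothing real.
    with torch.device("meta"):
        layer = nn.Linear(4096, 4096)

    print(layer.weight.device)  # meta
    # Real tensors are materialized later, when checkpoint weights are loaded.
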
@@ -183,6 +163,8 @@ class Base(
                trust_remote_code=self.model_config.trust_remote_code,
            )
+        # Create weight name to module qualname mapper
+        self._create_hf_to_vllm_mapper()
        # Remove layers not on this pipeline parallel rank
        self.pipeline_parallel()
        # Substitute remaining layers with vLLM's layers as needed
@@ -216,6 +198,104 @@ class Base(
["hidden_states"], self.text_config.hidden_size
)
def _patch_config(self):
"""
Patch the config to ensure that the model is created correctly:
- Sets the attention implementation to "vllm" so the attention instances from
`create_attention_instances` are used
- Sets the dtype to the default torch dtype set by vLLM because Transformers
uses the config dtype when creating the model
- Propagates this dtype to any sub-configs because Transformers model
implementations do not support/use different dtypes in sub-models
"""
self.text_config._attn_implementation = "vllm"
self.config.dtype = torch.get_default_dtype()
# TODO(hmellor): Remove this when Transformers v4 support is dropped
for sub_config_name in getattr(self.config, "sub_configs", {}):
sub_config = getattr(self.config, sub_config_name)
if sub_config.dtype != (dtype := self.config.dtype):
sub_config.dtype = dtype
def _create_hf_to_vllm_mapper(self):
"""
Create a WeightsMapper to map checkpoint weight names to module qualnames.
This handles:
- Transformers weight renaming:
- from `WeightRenaming` in Transformers v5
- from `_checkpoint_conversion_mapping` in Transformers v4
- Checkpoints saved with a base model prefix that is not `model`
- Checkpoints saved with no base model prefix
- Any quantization config specific mappings
"""
self.hf_to_vllm_mapper = WeightsMapper()
orig_to_new_regex = self.hf_to_vllm_mapper.orig_to_new_regex
if Version(transformers.__version__) >= Version("5.0.0"):
from transformers.conversion_mapping import (
WeightRenaming,
get_model_conversion_mapping,
)
for mapping in get_model_conversion_mapping(self.model):
# Handle weights which have been renamed in Transformers
if isinstance(mapping, WeightRenaming):
# Recompile using regex (Transformers used re)
compiled_sources = re.compile(
mapping.compiled_sources.pattern, mapping.compiled_sources.flags
)
target_pattern = mapping.target_patterns[0]
orig_to_new_regex[compiled_sources] = target_pattern
# TODO: Handle WeightConverter to enable layer merging
else:
# Replace legacy suffixes used for norms
# TODO(hmellor): Remove this when Transformers v4 support is dropped
orig_to_new_regex.update(
{
re.compile(r"\.gamma$"): ".weight",
re.compile(r"\.beta$"): ".bias",
}
)
# Handle weights which have been renamed in Transformers
# TODO(hmellor): Remove this when Transformers v4 support is dropped
ccm = getattr(self.model, "_checkpoint_conversion_mapping", {})
for source, target in ccm.items():
orig_to_new_regex[re.compile(source)] = target
# Handle unexpected weights which should be ignored
if self.model._keys_to_ignore_on_load_unexpected is not None:
for key in self.model._keys_to_ignore_on_load_unexpected:
orig_to_new_regex[re.compile(key)] = None
# Standardise base model prefix
bmp = self.model.base_model_prefix
expected_bmp = r"model.\1"
# Handle checkpoints saved with different base model prefix
if bmp and bmp != "model":
different_bmp_pattern = re.compile(rf"^{bmp}\.(.+)")
orig_to_new_regex[different_bmp_pattern] = expected_bmp
# Handle direct children of self.model which were saved without the model prefix
direct_children = chain(
self.model.named_children(),
self.model.named_parameters(recurse=False),
self.model.named_buffers(recurse=False),
)
model_children = "|".join(name for name, _ in direct_children)
missing_bmp_pattern = re.compile(rf"^(?!model\.)(({model_children}).*)")
orig_to_new_regex[missing_bmp_pattern] = expected_bmp
# Handle weights saved as direct children of self.model which no longer are
unexpected_bmp_pattern = re.compile(rf"^(model\.)((?!{model_children}).+)")
orig_to_new_regex[unexpected_bmp_pattern] = r"\2"
# Handle lm_head which was saved inside the base model
nested_lm_head_pattern = re.compile(r"^model\.(.+\.)*(lm_head.+)")
orig_to_new_regex[nested_lm_head_pattern] = r"\2"
# Apply mapping to quantization config if needed
self._maybe_apply_model_mapping()
def pipeline_parallel(self):
"""
Apply the model's pipeline parallelization plan.

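Note: to make the new regex-table approach concrete, here is an illustrative sketch of how a mapping like orig_to_new_regex can rename checkpoint keys. The apply() helper is hypothetical (vLLM's WeightsMapper performs this step internally); the sample patterns mirror the v4 suffix rule, the base-model-prefix standardisation, and the ignored-key handling in the hunk above:

    import regex as re

    # Ordered table: compiled pattern -> replacement template (None = drop key)
    orig_to_new_regex = {
        re.compile(r"\.gamma$"): ".weight",       # legacy norm suffix (v4 path)
        re.compile(r"^bert\.(.+)"): r"model.\1",  # non-"model" base model prefix
        re.compile(r"^cls\..*"): None,            # unexpected head, ignored
    }

    def apply(name: str) -> str | None:
        for pattern, repl in orig_to_new_regex.items():
            if pattern.search(name):
                if repl is None:
                    return None  # weight is dropped entirely
                name = pattern.sub(repl, name)
        return name

    print(apply("bert.encoder.layer.0.LayerNorm.gamma"))
    # model.encoder.layer.0.LayerNorm.weight
    print(apply("cls.predictions.bias"))
    # None
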
View File

@@ -20,7 +20,6 @@ from typing import TYPE_CHECKING
import torch
-from vllm.model_executor.models.utils import WeightsMapper
from vllm.sequence import IntermediateTensors
if TYPE_CHECKING:
@@ -28,20 +27,6 @@ if TYPE_CHECKING:
class LegacyMixin:
-    hf_to_vllm_mapper = WeightsMapper(
-        # These are applied in order, so the order matters!
-        orig_to_new_prefix={
-            # Handle BERT-like models
-            "roberta": "model",
-            "bert": "model",
-        },
-        orig_to_new_suffix={
-            # Replace legacy suffixes used for norms
-            ".gamma": ".weight",
-            ".beta": ".bias",
-        },
-    )
-
    def __init__(self, *, vllm_config: "VllmConfig", prefix: str = ""):
        super().__init__(vllm_config=vllm_config, prefix=prefix)

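Note: the suffix table removed here is presumably covered by the base class's v4 regex entries, and the BERT/RoBERTa prefix renames by its base-model-prefix standardisation. A quick equivalence check for the suffix half:

    import regex as re

    print(re.sub(r"\.gamma$", ".weight", "encoder.layer.0.LayerNorm.gamma"))
    # encoder.layer.0.LayerNorm.weight
    print(re.sub(r"\.beta$", ".bias", "encoder.layer.0.LayerNorm.beta"))
    # encoder.layer.0.LayerNorm.bias
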
View File

@@ -24,7 +24,6 @@ import torch
from vllm.config.utils import getattr_iter
from vllm.logger import init_logger
from vllm.model_executor.models.interfaces import SupportsMRoPE, SupportsMultiModal
-from vllm.model_executor.models.utils import WeightsMapper
from vllm.multimodal import MultiModalKwargsItems
from vllm.multimodal.inputs import (
    MultiModalDataDict,
@@ -273,30 +272,6 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]):
class MultiModalMixin(SupportsMultiModal, SupportsMRoPE):
    supports_multimodal_raw_input_only = True
-    # Backwards compatibility for previously released models. State dicts back
-    # then had different formats and cannot be loaded with `AutoModel` mapping as is
-    hf_to_vllm_mapper = WeightsMapper(
-        orig_to_new_prefix={
-            "language_model.model": "model.language_model",
-            "text_model.model": "model.text_model",
-            "vision_tower": "model.vision_tower",
-            "vqmodel": "model.vqmodel",
-            "visual": "model.visual",
-            "vision_model": "model.vision_model",
-            "vision_embed_tokens": "model.vision_embed_tokens",
-            "image_newline": "model.image_newline",
-            "multi_modal_projector": "model.multi_modal_projector",
-            "text_model.lm_head": "lm_head",
-            "language_model.lm_head": "lm_head",
-            # Qwen models used "model" as the name for the language model.
-            # Therefore, we must map each submodule explicitly to avoid
-            # conflicts with newer models that use "model.language_model".
-            "model.embed_tokens": "model.language_model.embed_tokens",
-            "model.layers": "model.language_model.layers",
-            "model.norm": "model.language_model.norm",
-        }
-    )
-
    def __init__(self, *, vllm_config: "VllmConfig", prefix: str = ""):
        # Skip SupportsMRoPE.__init__ and call the next class in MRO
        super(SupportsMRoPE, self).__init__(vllm_config=vllm_config, prefix=prefix)
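
Note: the Qwen caveat in the removed comment is worth illustrating. A blanket "model." -> "model.language_model." rule would also rewrite newer checkpoints that already use "model.language_model.", which is why the table named each submodule explicitly. A small sketch (map_prefix is a hypothetical helper, not vLLM code):

    def map_prefix(name: str, old: str, new: str) -> str:
        return new + name[len(old):] if name.startswith(old) else name

    # Old Qwen-style checkpoint: the rewrite is wanted.
    print(map_prefix("model.layers.0.self_attn.q_proj.weight",
                     "model.layers", "model.language_model.layers"))
    # model.language_model.layers.0.self_attn.q_proj.weight

    # Newer checkpoint: a blanket prefix rule would double the prefix.
    print(map_prefix("model.language_model.layers.0.mlp.up_proj.weight",
                     "model.", "model.language_model."))
    # model.language_model.language_model.layers.0.mlp.up_proj.weight  (wrong!)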