Use Transformers v5 WeightRenaming for Transformers modeling backend (#31545)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
@@ -17,6 +17,7 @@
"""Transformers modeling backend base class."""

from collections.abc import Iterable
from itertools import chain
from typing import TYPE_CHECKING

import regex as re
@@ -107,27 +108,6 @@ class Base(
    SupportsEagle3,
):
    embedding_modules = ["embed_tokens"]  # TODO transformers will have a util to get it
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            # Add `model.` prefix for base model checkpoints,
            # handling the case where it is already present
            "": "model.",
            "model.model.": "model.",
            # Heads will be adjacent to `model` (pooling included because of adapters)
            "model.lm_head.": "lm_head.",
            "model.score.": "classifier.",
            "model.classifier.": "classifier.",
        }
    )

    def __init_subclass__(cls, *args, **kwargs):
        """Merge hf_to_vllm_mapper in MRO from most specific to least specific."""
        super().__init_subclass__(*args, **kwargs)
        hf_to_vllm_mapper = WeightsMapper()
        for base in cls.__mro__:
            if base_hf_to_vllm_mapper := getattr(base, "hf_to_vllm_mapper", None):
                hf_to_vllm_mapper |= base_hf_to_vllm_mapper
        cls.hf_to_vllm_mapper = hf_to_vllm_mapper

    def __init__(self, *, vllm_config: "VllmConfig", prefix: str = ""):
        super().__init__()
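For illustration only (not part of the diff): a minimal standalone sketch of the merge order described by the `__init_subclass__` docstring above, using toy classes and plain dicts instead of vLLM's `WeightsMapper`. The class names below are hypothetical.

class ToyBase:
    hf_to_vllm_mapper = {"": "model.", "model.model.": "model."}

class ToyMixin(ToyBase):
    hf_to_vllm_mapper = {"vision_tower": "model.vision_tower"}

def merged_mapper(cls: type) -> dict:
    merged = {}
    # Walk the MRO from least to most specific so that the most specific
    # class's entries win on conflicts, as the docstring describes.
    for klass in reversed(cls.__mro__):
        merged.update(getattr(klass, "hf_to_vllm_mapper", None) or {})
    return merged

print(merged_mapper(ToyMixin))
# {'': 'model.', 'model.model.': 'model.', 'vision_tower': 'model.vision_tower'}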
@@ -174,8 +154,8 @@ class Base(
        if "gptq" in quant_method_name:
            self.ignore_unexpected_suffixes.append(".bias")

        # Set correct attn and init on "meta" to delay allocating GPU tensors
        self.text_config._attn_implementation = "vllm"
        # Patch config and init on "meta" to delay allocating GPU tensors
        self._patch_config()
        with init_on_device_without_buffers("meta"):
            self.model: PreTrainedModel = AutoModel.from_config(
                self.config,
@@ -183,6 +163,8 @@ class Base(
                trust_remote_code=self.model_config.trust_remote_code,
            )

        # Create weight name to module qualname mapper
        self._create_hf_to_vllm_mapper()
        # Remove layers not on this pipeline parallel rank
        self.pipeline_parallel()
        # Substitute remaining layers with vLLM's layers as needed
@@ -216,6 +198,104 @@ class Base(
            ["hidden_states"], self.text_config.hidden_size
        )

    def _patch_config(self):
        """
        Patch the config to ensure that the model is created correctly:

        - Sets the attention implementation to "vllm" so the attention instances from
          `create_attention_instances` are used
        - Sets the dtype to the default torch dtype set by vLLM because Transformers
          uses the config dtype when creating the model
        - Propagates this dtype to any sub-configs because Transformers model
          implementations do not support/use different dtypes in sub-models
        """
        self.text_config._attn_implementation = "vllm"
        self.config.dtype = torch.get_default_dtype()
        # TODO(hmellor): Remove this when Transformers v4 support is dropped
        for sub_config_name in getattr(self.config, "sub_configs", {}):
            sub_config = getattr(self.config, sub_config_name)
            if sub_config.dtype != (dtype := self.config.dtype):
                sub_config.dtype = dtype

    def _create_hf_to_vllm_mapper(self):
        """
        Create a WeightsMapper to map checkpoint weight names to module qualnames.

        This handles:

        - Transformers weight renaming:
            - from `WeightRenaming` in Transformers v5
            - from `_checkpoint_conversion_mapping` in Transformers v4
        - Checkpoints saved with a base model prefix that is not `model`
        - Checkpoints saved with no base model prefix
        - Any quantization config specific mappings
        """
        self.hf_to_vllm_mapper = WeightsMapper()
        orig_to_new_regex = self.hf_to_vllm_mapper.orig_to_new_regex

        if Version(transformers.__version__) >= Version("5.0.0"):
            from transformers.conversion_mapping import (
                WeightRenaming,
                get_model_conversion_mapping,
            )

            for mapping in get_model_conversion_mapping(self.model):
                # Handle weights which have been renamed in Transformers
                if isinstance(mapping, WeightRenaming):
                    # Recompile using `regex` (Transformers compiled it with the stdlib `re`)
                    compiled_sources = re.compile(
                        mapping.compiled_sources.pattern, mapping.compiled_sources.flags
                    )
                    target_pattern = mapping.target_patterns[0]
                    orig_to_new_regex[compiled_sources] = target_pattern
                # TODO: Handle WeightConverter to enable layer merging
        else:
            # Replace legacy suffixes used for norms
            # TODO(hmellor): Remove this when Transformers v4 support is dropped
            orig_to_new_regex.update(
                {
                    re.compile(r"\.gamma$"): ".weight",
                    re.compile(r"\.beta$"): ".bias",
                }
            )

            # Handle weights which have been renamed in Transformers
            # TODO(hmellor): Remove this when Transformers v4 support is dropped
            ccm = getattr(self.model, "_checkpoint_conversion_mapping", {})
            for source, target in ccm.items():
                orig_to_new_regex[re.compile(source)] = target

        # Handle unexpected weights which should be ignored
        if self.model._keys_to_ignore_on_load_unexpected is not None:
            for key in self.model._keys_to_ignore_on_load_unexpected:
                orig_to_new_regex[re.compile(key)] = None

        # Standardise base model prefix
        bmp = self.model.base_model_prefix
        expected_bmp = r"model.\1"
        # Handle checkpoints saved with a different base model prefix
        if bmp and bmp != "model":
            different_bmp_pattern = re.compile(rf"^{bmp}\.(.+)")
            orig_to_new_regex[different_bmp_pattern] = expected_bmp
        # Handle direct children of self.model which were saved without the model prefix
        direct_children = chain(
            self.model.named_children(),
            self.model.named_parameters(recurse=False),
            self.model.named_buffers(recurse=False),
        )
        model_children = "|".join(name for name, _ in direct_children)
        missing_bmp_pattern = re.compile(rf"^(?!model\.)(({model_children}).*)")
        orig_to_new_regex[missing_bmp_pattern] = expected_bmp
        # Handle weights saved as direct children of self.model which are no longer direct children
        unexpected_bmp_pattern = re.compile(rf"^(model\.)((?!{model_children}).+)")
        orig_to_new_regex[unexpected_bmp_pattern] = r"\2"
        # Handle lm_head which was saved inside the base model
        nested_lm_head_pattern = re.compile(r"^model\.(.+\.)*(lm_head.+)")
        orig_to_new_regex[nested_lm_head_pattern] = r"\2"

        # Apply mapping to quantization config if needed
        self._maybe_apply_model_mapping()

    def pipeline_parallel(self):
        """
        Apply the model's pipeline parallelization plan.
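For illustration only (not part of the diff): a minimal, standalone sketch of how an ordered table of compiled regex -> replacement rules like `orig_to_new_regex` above can rewrite checkpoint names, with None dropping a key. It uses the stdlib `re` and made-up rule patterns for brevity; vLLM's `WeightsMapper` applies its own logic.

import re
from typing import Optional

rules: dict = {
    re.compile(r"\.gamma$"): ".weight",                                 # legacy norm suffix
    re.compile(r"^(?!model\.)((embed_tokens|layers).*)"): r"model.\1",  # add missing base prefix
    re.compile(r"^model\.(.+\.)*(lm_head.+)"): r"\2",                   # lift nested lm_head
    re.compile(r"rotary_emb\.inv_freq"): None,                          # unexpected key: drop
}

def rename(name: str) -> Optional[str]:
    # Apply each rule in order; a None target means the weight is ignored.
    for pattern, target in rules.items():
        if pattern.search(name):
            if target is None:
                return None
            name = pattern.sub(target, name)
    return name

print(rename("layers.0.input_layernorm.gamma"))      # model.layers.0.input_layernorm.weight
print(rename("model.lm_head.weight"))                 # lm_head.weight
print(rename("model.layers.0.rotary_emb.inv_freq"))   # None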
@@ -20,7 +20,6 @@ from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.models.utils import WeightsMapper
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -28,20 +27,6 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
class LegacyMixin:
|
||||
hf_to_vllm_mapper = WeightsMapper(
|
||||
# These are applied in order, so the order matters!
|
||||
orig_to_new_prefix={
|
||||
# Handle BERT-like models
|
||||
"roberta": "model",
|
||||
"bert": "model",
|
||||
},
|
||||
orig_to_new_suffix={
|
||||
# Replace legacy suffixes used for norms
|
||||
".gamma": ".weight",
|
||||
".beta": ".bias",
|
||||
},
|
||||
)
|
||||
|
||||
def __init__(self, *, vllm_config: "VllmConfig", prefix: str = ""):
|
||||
super().__init__(vllm_config=vllm_config, prefix=prefix)
|
||||
|
||||
|
||||
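For illustration only (not part of the diff): a standalone sketch of the kind of rename the LegacyMixin rules above describe, with prefix rules tried in order and legacy norm suffixes replaced afterwards. This is a simplification, not vLLM's `WeightsMapper` implementation.

prefix_rules = {"roberta": "model", "bert": "model"}
suffix_rules = {".gamma": ".weight", ".beta": ".bias"}

def legacy_rename(name: str) -> str:
    # First matching prefix rule wins, then the first matching suffix rule.
    for old, new in prefix_rules.items():
        if name.startswith(old):
            name = new + name[len(old):]
            break
    for old, new in suffix_rules.items():
        if name.endswith(old):
            name = name[: -len(old)] + new
            break
    return name

print(legacy_rename("bert.encoder.layer.0.attention.output.LayerNorm.gamma"))
# model.encoder.layer.0.attention.output.LayerNorm.weight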
@@ -24,7 +24,6 @@ import torch
|
||||
from vllm.config.utils import getattr_iter
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.models.interfaces import SupportsMRoPE, SupportsMultiModal
|
||||
from vllm.model_executor.models.utils import WeightsMapper
|
||||
from vllm.multimodal import MultiModalKwargsItems
|
||||
from vllm.multimodal.inputs import (
|
||||
MultiModalDataDict,
|
||||
@@ -273,30 +272,6 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]):
class MultiModalMixin(SupportsMultiModal, SupportsMRoPE):
    supports_multimodal_raw_input_only = True

    # Backwards compatibility for previously released models. State dicts back then
    # had different formats and cannot be loaded with the `AutoModel` mapping as-is
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            "language_model.model": "model.language_model",
            "text_model.model": "model.text_model",
            "vision_tower": "model.vision_tower",
            "vqmodel": "model.vqmodel",
            "visual": "model.visual",
            "vision_model": "model.vision_model",
            "vision_embed_tokens": "model.vision_embed_tokens",
            "image_newline": "model.image_newline",
            "multi_modal_projector": "model.multi_modal_projector",
            "text_model.lm_head": "lm_head",
            "language_model.lm_head": "lm_head",
            # Qwen models used "model" as the name for the language model.
            # Therefore, we must map each submodule explicitly to avoid
            # conflicts with newer models that use "model.language_model".
            "model.embed_tokens": "model.language_model.embed_tokens",
            "model.layers": "model.language_model.layers",
            "model.norm": "model.language_model.norm",
        }
    )

    def __init__(self, *, vllm_config: "VllmConfig", prefix: str = ""):
        # Skip SupportsMRoPE.__init__ and call the next class in MRO
        super(SupportsMRoPE, self).__init__(vllm_config=vllm_config, prefix=prefix)
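For illustration only (not part of the diff): a standalone example of the two-argument super() call used above. super(SomeClass, self) starts the MRO lookup after the named class, so that class's __init__ is skipped. The classes below are toy placeholders.

class A:
    def __init__(self):
        print("A.__init__")

class Skipped(A):
    def __init__(self):
        print("Skipped.__init__ (not called below)")

class C(Skipped):
    def __init__(self):
        # Skip Skipped.__init__ and call the next class in the MRO (A).
        super(Skipped, self).__init__()

C()  # prints "A.__init__" only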