# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Copyright 2024 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Transformers modeling backend base class."""

from collections.abc import Callable, Iterable
from itertools import chain
from operator import attrgetter
from typing import TYPE_CHECKING

import regex as re
import torch
import transformers
from packaging.version import Version
from torch import nn
from transformers import AutoModel
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS

from vllm.config.utils import getattr_iter
from vllm.distributed import get_pp_group, get_tp_group
from vllm.distributed.utils import get_pp_indices
from vllm.logger import init_logger
from vllm.model_executor.layers.attention import (
    Attention,
    EncoderOnlyAttention,
)
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.model_executor.models.interfaces import (
    SupportsEagle,
    SupportsEagle3,
    SupportsLoRA,
    SupportsPP,
    SupportsQuant,
)
from vllm.model_executor.models.interfaces_base import VllmModel
from vllm.model_executor.models.transformers.utils import (
    get_feature_request_tip,
    init_on_device_without_buffers,
    log_replacement,
    replace_conv_class,
    replace_linear_class,
    replace_rms_norm_class,
)
from vllm.model_executor.models.utils import (
    AutoWeightsLoader,
    PPMissingLayer,
    WeightsMapper,
    make_empty_intermediate_tensors_factory,
    maybe_prefix,
)
from vllm.sequence import IntermediateTensors
from vllm.v1.attention.backend import AttentionType

if TYPE_CHECKING:
    from transformers import PreTrainedModel

    from vllm.config import VllmConfig
else:
    PreTrainedModel = object

logger = init_logger(__name__)


def vllm_flash_attention_forward(
    # Transformers args
    module: torch.nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: torch.Tensor,
    # Transformers kwargs
    scaling: float | None = None,
    # vLLM kwargs
    attention_instances: dict[int, Attention] | None = None,
    **kwargs,
):
    """Attention function registered with Transformers under the key `"vllm"`.

    Dispatches the attention computation to the vLLM attention layer stored in
    `attention_instances` under `module.layer_idx`.

    Args:
        module: The calling attention module; only `module.layer_idx` is read.
        query: Query tensor.
        key: Key tensor.
        value: Value tensor.
        attention_mask: Accepted for interface compatibility; not used here.
        scaling: If not `None`, overwrites `impl.scale` of the selected layer.
        attention_instances: Mapping from layer index to attention layer. It is
            indexed unconditionally, so it must be provided even though the
            default is `None`.
        **kwargs: Ignored.

    Returns:
        A tuple of the attention output and `None`.
    """
    self_attn = attention_instances[module.layer_idx]
    if scaling is not None:
        self_attn.impl.scale = float(scaling)
    # Read before the transpose below.
    # NOTE(review): presumably q/k/v arrive as (batch, heads, tokens, head_dim)
    # with a batch of 1, which would make this the token count - confirm.
    hidden = query.shape[-2]
    query, key, value = (x.transpose(1, 2) for x in (query, key, value))
    query, key, value = (x.reshape(hidden, -1) for x in (query, key, value))
    return self_attn.forward(query, key, value), None


ALL_ATTENTION_FUNCTIONS["vllm"] = vllm_flash_attention_forward


class Base(
    nn.Module,
    VllmModel,
    SupportsQuant,
    SupportsLoRA,
    SupportsPP,
    SupportsEagle,
    SupportsEagle3,
):
    """Base class for models run through the Transformers modeling backend.

    The model is built with `AutoModel.from_config` on the `meta` device, layers
    that do not belong to this pipeline parallel rank are replaced with
    `PPMissingLayer`, supported modules are swapped for vLLM implementations and
    attention is routed to vLLM `Attention` instances.
    """

    embedding_modules = ["embed_tokens"]  # TODO transformers will have a util to get it

    def __init__(self, *, vllm_config: "VllmConfig", prefix: str = ""):
        """Build the Transformers model and adapt it for use in vLLM.

        Args:
            vllm_config: The vLLM config from which the model, cache, device,
                parallel and quantization configs are read.
            prefix: Accepted for interface compatibility; not used here.

        Raises:
            NotImplementedError: If the quantization method is `"mxfp4"`.
        """
        super().__init__()
        logger.info("Using Transformers modeling backend.")

        self.config = vllm_config.model_config.hf_config
        self.text_config = self.config.get_text_config()
        self.cache_config = vllm_config.cache_config
        self.device_config = vllm_config.device_config
        self.model_config = vllm_config.model_config
        self.parallel_config = vllm_config.parallel_config
        self.quant_config = vllm_config.quant_config

        self.pp_group = get_pp_group()
        self.tp_group = get_tp_group()

        # Attrs for weight loading (see self.load_weights)
        self.skip_prefixes: list[str] = []
        """Skip loading weights whose qualname starts with these prefixes."""
        self.skip_substrs: list[str] = []
        """Skip loading weights whose qualname contains these substrings."""
        self.ignore_unexpected_prefixes: list[str] = []
        """Ignore unexpected weights whose qualname starts with these prefixes."""
        self.ignore_unexpected_suffixes: list[str] = []
        """Ignore unexpected weights whose qualname ends with these suffixes."""

        # Attrs for Eagle3 (see self.set_aux_hidden_state_layers)
        self._target_class: type[nn.Module] = nn.Module
        """Target class for Eagle3 aux hidden state recording."""
        self._layer_names: dict[int, str] = {}
        """Mapping from layer index to layer name for Eagle3."""
        self._output_aux_hidden_states_kwargs: dict[str, bool] = {}
        """Kwargs to pass to model forward for Eagle3 aux hidden states."""

        if self.quant_config:
            quant_method_name = self.quant_config.get_name()
            # Check for unsupported quantization methods.
            if quant_method_name == "mxfp4":
                raise NotImplementedError(
                    "Transformers modeling backend does "
                    "not support MXFP4 quantization yet."
                )
            # Skip loading extra bias for GPTQ models.
            if "gptq" in quant_method_name:
                self.ignore_unexpected_suffixes.append(".bias")

        # Patch config and init on "meta" to delay allocating GPU tensors
        self._patch_config()
        with init_on_device_without_buffers("meta"):
            self.model: PreTrainedModel = AutoModel.from_config(
                self.config,
                dtype=self.model_config.dtype,
                trust_remote_code=self.model_config.trust_remote_code,
            )

        # Create weight name to module qualname mapper
        self._create_hf_to_vllm_mapper()

        # Remove layers not on this pipeline parallel rank
        self.pipeline_parallel()
        # Substitute remaining layers with vLLM's layers as needed
        self.recursive_replace()
        # Create attention instances for KV cache allocation
        self.attention_instances = self.create_attention_instances()

        # Input embeddings
        self.embed_scale = None
        input_embeddings = self.model.get_input_embeddings()
        if not isinstance(input_embeddings, PPMissingLayer):
            # Some models scale embeddings inside the input embedding layer
            self.embed_scale = getattr(input_embeddings, "embed_scale", None)
            names = ("embedding_size", "hidden_size")
            embedding_dim = getattr_iter(self.text_config, names, None)
            assert embedding_dim is not None
            self.model.set_input_embeddings(
                VocabParallelEmbedding(
                    self.text_config.vocab_size,
                    embedding_dim=embedding_dim,
                    org_num_embeddings=self.text_config.vocab_size,
                    quant_config=self.quant_config,
                )
            )

        # Initialize any parameters that have not had their modules replaced
        self.init_parameters(self.model)

        # Pipeline parallel intermediate tensors
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states"], self.text_config.hidden_size
        )

    def _patch_config(self):
        """
        Patch the config to ensure that the model is created correctly:
        - Sets the attention implementation to "vllm" so the attention instances from
          `create_attention_instances` are used
        - Sets the dtype to the default torch dtype set by vLLM because Transformers
          uses the config dtype when creating the model
        - Propagates this dtype to any sub-configs because Transformers model
          implementations do not support/use different dtypes in sub-models
        """
        self.text_config._attn_implementation = "vllm"
        self.config.dtype = torch.get_default_dtype()
        # TODO(hmellor): Remove this when Transformers v4 support is dropped
        for sub_config_name in getattr(self.config, "sub_configs", {}):
            sub_config = getattr(self.config, sub_config_name)
            if sub_config.dtype != (dtype := self.config.dtype):
                sub_config.dtype = dtype

    def _create_hf_to_vllm_mapper(self):
        """
        Create a WeightsMapper to map checkpoint weight names to module qualnames.

        This handles:

        - Transformers weight renaming:
            - from `WeightRenaming` in Transformers v5
            - from `_checkpoint_conversion_mapping` in Transformers v4
        - Checkpoints saved with a base model prefix that is not `model`
        - Checkpoints saved with no base model prefix
        - Any quantization config specific mappings

        The mapper is stored in `self.hf_to_vllm_mapper`. Patterns are inserted in
        the order in which they appear below.
        """
        self.hf_to_vllm_mapper = WeightsMapper()
        orig_to_new_regex = self.hf_to_vllm_mapper.orig_to_new_regex

        if Version(transformers.__version__) >= Version("5.0.0"):
            from transformers.conversion_mapping import (
                WeightRenaming,
                get_model_conversion_mapping,
            )

            for mapping in get_model_conversion_mapping(self.model):
                # Handle weights which have been renamed in Transformers
                if isinstance(mapping, WeightRenaming):
                    # Recompile using regex (Transformers used re)
                    compiled_sources = re.compile(
                        mapping.compiled_sources.pattern,
                        mapping.compiled_sources.flags,
                    )
                    target_pattern = mapping.target_patterns[0]
                    orig_to_new_regex[compiled_sources] = target_pattern
                # TODO: Handle WeightConverter to enable layer merging
        else:
            # Replace legacy suffixes used for norms
            # TODO(hmellor): Remove this when Transformers v4 support is dropped
            orig_to_new_regex.update(
                {
                    re.compile(r"\.gamma$"): ".weight",
                    re.compile(r"\.beta$"): ".bias",
                }
            )
            # Handle weights which have been renamed in Transformers
            # TODO(hmellor): Remove this when Transformers v4 support is dropped
            ccm = getattr(self.model, "_checkpoint_conversion_mapping", {})
            for source, target in ccm.items():
                orig_to_new_regex[re.compile(source)] = target

        # Handle unexpected weights which should be ignored
        if self.model._keys_to_ignore_on_load_unexpected is not None:
            for key in self.model._keys_to_ignore_on_load_unexpected:
                orig_to_new_regex[re.compile(key)] = None

        # Standardise base model prefix
        bmp = self.model.base_model_prefix
        expected_bmp = r"model.\1"
        # Handle checkpoints saved with different base model prefix
        if bmp and bmp != "model":
            # Escape so the prefix is always matched literally
            different_bmp_pattern = re.compile(rf"^{re.escape(bmp)}\.(.+)")
            orig_to_new_regex[different_bmp_pattern] = expected_bmp
        # Handle direct children of self.model which were saved without the model prefix
        direct_children = chain(
            self.model.named_children(),
            self.model.named_parameters(recurse=False),
            self.model.named_buffers(recurse=False),
        )
        # Escape so the names are always matched literally
        model_children = "|".join(re.escape(name) for name, _ in direct_children)
        missing_bmp_pattern = re.compile(rf"^(?!model\.)(({model_children}).*)")
        orig_to_new_regex[missing_bmp_pattern] = expected_bmp
        # Handle weights saved as direct children of self.model which no longer are
        unexpected_bmp_pattern = re.compile(rf"^(model\.)((?!{model_children}).+)")
        orig_to_new_regex[unexpected_bmp_pattern] = r"\2"
        # Handle lm_head which was saved inside the base model
        nested_lm_head_pattern = re.compile(r"^model\.(.+\.)*(lm_head.+)")
        orig_to_new_regex[nested_lm_head_pattern] = r"\2"

        # Apply mapping to quantization config if needed
        self._maybe_apply_model_mapping()

    def _get_tie_word_embeddings(self):
        """
        Check if the model has tied word embeddings.

        Returns the `tie_word_embeddings` value of either the text config or the
        top level config, whichever is truthy (`False` if neither defines it).
        """
        # Transformers v4 and v5 will store this in different places
        tie_word_embeddings_v4 = getattr(self.text_config, "tie_word_embeddings", False)
        tie_word_embeddings_v5 = getattr(self.config, "tie_word_embeddings", False)
        return tie_word_embeddings_v4 or tie_word_embeddings_v5

    def pipeline_parallel(self):
        """
        Apply the model's pipeline parallelization plan.

        Does nothing when the pipeline parallel world size is 1 or less. Otherwise
        every module named in the model's `_pp_plan` that does not belong to this
        rank is replaced with a `PPMissingLayer`.

        Raises:
            ValueError: If the model does not support a pipeline parallel plan, if
                the plan contains more than one `nn.ModuleList`, or if it contains
                none.
        """
        if self.pp_group.world_size <= 1:
            return

        if not self.model.supports_pp_plan:
            tip = get_feature_request_tip(
                self.model_config.model, self.model_config.trust_remote_code
            )
            raise ValueError(
                f"{type(self.model)} does not support pipeline parallel. {tip}"
            )

        def attrsetter(attr: str) -> Callable[[object, object], None]:
            """Set a possibly nested attribute, like the inverse of attrgetter."""
            parent, _, name = attr.rpartition(".")

            def setter(obj: object, value: object):
                attr_parent = attrgetter(parent)(obj) if parent else obj
                setattr(attr_parent, name, value)

            return setter

        module_lists = []
        module_list_idx = None
        pp_plan = list(self.model._pp_plan.keys())
        for i, name in enumerate(pp_plan):
            # attrgetter in case the module is nested (e.g. "text_model.layers")
            if isinstance(attrgetter(name)(self.model), nn.ModuleList):
                module_lists.append(name)
                module_list_idx = i

        if len(module_lists) > 1:
            raise ValueError(
                "Pipeline parallel of models with multiple `ModuleList`s "
                "in the base model are not supported yet!"
            )
        if module_list_idx is None:
            raise ValueError(f"Could not find `ModuleList` in {type(self.model)}")

        # Layers before module list
        for name in pp_plan[:module_list_idx]:
            if self.pp_group.is_first_rank or (
                self._get_tie_word_embeddings() and self.pp_group.is_last_rank
            ):
                continue
            # attrsetter in case the module is nested (e.g. "text_model.embed_tokens")
            attrsetter(name)(self.model, PPMissingLayer())

        # Module list
        start_layer, end_layer = get_pp_indices(
            self.text_config.num_hidden_layers,
            self.pp_group.rank_in_group,
            self.pp_group.world_size,
        )
        layers_name = pp_plan[module_list_idx]
        # attrgetter in case the module is nested (e.g. "text_model.layers")
        layers = attrgetter(layers_name)(self.model)
        for i in range(len(layers)):
            if start_layer <= i and i < end_layer:
                continue
            layers[i] = PPMissingLayer()

        # Layers after module list
        for name in pp_plan[module_list_idx + 1 :]:
            # Modules that should be on last rank
            if not self.pp_group.is_last_rank:
                # attrsetter in case the module is nested (e.g. "text_model.norm")
                attrsetter(name)(self.model, PPMissingLayer())

    def recursive_replace(self):
        """Recursively replace modules in the model as needed.

        Currently, this replaces:

        - `nn.Linear` with vLLM's tensor parallel linear classes
        - `nn.Conv2d` and `nn.Conv3d` using `replace_conv_class`
        - `*RMSNorm` with vLLM's `RMSNorm`

        While walking the model it also records the Eagle3 layer attributes and
        registers the prefixes of MTP layers so that their weights are ignored.

        Raises:
            ValueError: If the model has no tensor parallel plan and the tensor
                parallel world size is greater than 1.
        """
        # A missing plan is treated like an empty one so that `.items()` below
        # cannot fail when tensor parallelism is not in use
        tp_plan = self.model.tp_plan or {}

        if not tp_plan and self.tp_group.world_size > 1:
            tip = get_feature_request_tip(
                self.model_config.model, self.model_config.trust_remote_code
            )
            raise ValueError(
                f"{type(self.model)} does not support tensor parallel. {tip}"
            )

        # Prefix the patterns because we always start from `self.model`
        tp_plan = {maybe_prefix("model", k): v for k, v in tp_plan.items()}

        # The number of MTP layers does not change during replacement, so it is
        # looked up once here instead of once per decoder layer
        mtp_names = (
            "n_predict",  # Override from SpeculativeConfig
            "num_nextn_predict_layers",  # Most models
            "mtp_num_hidden_layers",  # Qwen 3.5
        )
        n_predict = getattr_iter(self.text_config, mtp_names, 0)

        def _recursive_replace(module: nn.Module, prefix: str):
            for child_name, child_module in module.named_children():
                new_module = child_module
                qual_name = maybe_prefix(prefix, child_name)
                if (
                    isinstance(module, nn.ModuleList)
                    and len(module) == self.text_config.num_hidden_layers
                ):
                    # Populate Eagle3 attrs
                    self._target_class = type(child_module)
                    layer_name = qual_name.removeprefix("model.")
                    self._layer_names[int(child_name)] = layer_name
                    # MTP weights should not be loaded into the base model
                    num_hidden_layers = self.text_config.num_hidden_layers
                    for i in range(num_hidden_layers, num_hidden_layers + n_predict):
                        mtp_prefix = f"{prefix}.{i}."
                        if mtp_prefix not in self.ignore_unexpected_prefixes:
                            self.ignore_unexpected_prefixes.append(mtp_prefix)
                # Replace modules as needed
                if isinstance(child_module, nn.Linear):
                    generator = (p for p in tp_plan if re.match(p, qual_name))
                    pattern = next(generator, None)
                    # Some weight loaders expect all linear layers to inherit
                    # LinearBase, so we set a default style which causes any
                    # unspecified layers to be replaced with ReplicatedLinear
                    style = tp_plan.get(pattern, "replicate")
                    new_module = replace_linear_class(
                        child_module, style, self.quant_config, prefix=qual_name
                    )
                elif isinstance(child_module, (nn.Conv2d, nn.Conv3d)):
                    new_module = replace_conv_class(child_module)
                elif child_module.__class__.__name__.endswith("RMSNorm"):
                    new_module = replace_rms_norm_class(
                        child_module, self.text_config.hidden_size
                    )
                else:
                    _recursive_replace(child_module, prefix=qual_name)

                if new_module is not child_module:
                    setattr(module, child_name, new_module)
                    log_replacement(qual_name, child_module, new_module)

        _recursive_replace(self.model, prefix="model")

    def create_attention_instances(self) -> dict[int, Attention]:
        """
        Create `Attention` instances to inform KV cache allocation.

        One instance is created for every layer index assigned to this pipeline
        parallel rank by `get_pp_indices`.

        Returns:
            Mapping from layer index to the attention instance for that layer.
        """
        text_config = self.text_config

        num_heads = self.model_config.get_num_attention_heads(self.parallel_config)
        head_size = self.model_config.get_head_size()
        num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config)
        logits_soft_cap = getattr(text_config, "attn_logit_softcapping", None)

        # In encoder models, the attention layers will have `is_causal=False`
        is_encoder = lambda module: not getattr(module, "is_causal", True)
        has_encoder = lambda model: any(is_encoder(m) for m in model.modules())
        is_multimodal = lambda config: config != config.get_text_config()
        # vLLM does not support encoder-decoder models, so if any encoder layer is
        # found in a text only model, we assume the whole model is an encoder model
        if has_encoder(self.model) and not is_multimodal(self.config):
            self.check_version("5.0.0", "encoder models support")
            attn_type = AttentionType.ENCODER_ONLY
        else:
            attn_type = AttentionType.DECODER

        # The attention class only depends on the attention type, so it is
        # selected once for all layers
        attn_cls = (
            EncoderOnlyAttention
            if attn_type == AttentionType.ENCODER_ONLY
            else Attention
        )

        pp_rank = self.pp_group.rank_in_group
        pp_size = self.pp_group.world_size
        start, end = get_pp_indices(text_config.num_hidden_layers, pp_rank, pp_size)

        attention_instances = {}
        for i in range(start, end):
            # Handle interleaved sliding window attention
            per_layer_sliding_window = None
            if (
                hasattr(self.config, "layer_types")
                and self.config.layer_types[i] == "sliding_attention"
            ):
                per_layer_sliding_window = self.config.sliding_window

            attention_instances[i] = attn_cls(
                num_heads=num_heads,
                head_size=head_size,
                # NOTE: We use Llama scale as default, if it's set by
                # Transformers, it's updated in vllm_flash_attention_forward
                scale=head_size**-0.5,
                num_kv_heads=num_kv_heads,
                cache_config=self.cache_config,
                quant_config=self.quant_config,
                logits_soft_cap=logits_soft_cap,
                per_layer_sliding_window=per_layer_sliding_window,
                prefix=f"{i}.attn",
                attn_type=attn_type,
            )
        return attention_instances

    def init_parameters(self, module: nn.Module, dtype: torch.dtype | None = None):
        """
        If a `parameter` is on the `meta` device, then its parent
        `module` is the original module created by:

        ```python
        with torch.device("meta"):
            self.model: "PreTrainedModel" = AutoModel.from_config(...)
        ```

        Every such parameter in `module` and its descendants is replaced with an
        uninitialised parameter of the same shape on `self.device_config.device`.

        Args:
            module: Root module whose parameters are materialised.
            dtype: Dtype of the new parameters. Falls back to
                `self.model_config.dtype` when `None`.
        """

        def _init_parameters(module: nn.Module, dtype: torch.dtype | None):
            for name, param in module.named_parameters(recurse=False):
                if param.device == torch.device("meta"):
                    new_param = nn.Parameter(
                        torch.empty_like(
                            param.data,
                            dtype=dtype or self.model_config.dtype,
                            device=self.device_config.device,
                        )
                    )
                    setattr(module, name, new_param)
            for child in module.children():
                _init_parameters(child, dtype)

        _init_parameters(module, dtype)

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Embed `input_ids` with the model's input embedding layer.

        The embeddings are multiplied in place by `self.embed_scale` when it is
        not `None`.
        """
        inputs_embeds = self.model.get_input_embeddings()(input_ids)
        if self.embed_scale is not None:
            inputs_embeds *= self.embed_scale
        return inputs_embeds

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the wrapped Transformers model.

        Args:
            input_ids: Token ids, or `None` when `inputs_embeds` is given.
            positions: Positions of the tokens.
            intermediate_tensors: Output of the previous pipeline parallel rank.
                Required on every rank except the first, where its
                `"hidden_states"` entry replaces `input_ids`/`inputs_embeds`.
            inputs_embeds: Precomputed input embeddings.
            **kwargs: Forwarded to the Transformers model.

        Returns:
            `IntermediateTensors` holding the hidden states on every rank except
            the last. On the last rank, the hidden states, or a tuple of the
            hidden states and the aux hidden states when aux hidden state
            recording is enabled and produced at least one output.
        """
        if not self.pp_group.is_first_rank:
            assert intermediate_tensors is not None
            input_ids = None
            inputs_embeds = intermediate_tensors["hidden_states"]

        # A leading dimension is added here and removed from the outputs below
        if input_ids is not None:
            input_ids = input_ids[None, ...]
        if inputs_embeds is not None:
            inputs_embeds = inputs_embeds[None, ...]

        # If the model scales embeddings inside the input embedding layer we must
        # ensure they are scaled here since VocabParallelEmbedding will not do it
        if (
            self.embed_scale is not None
            and input_ids is not None
            and inputs_embeds is None
        ):
            inputs_embeds = self.embed_input_ids(input_ids)
            input_ids = None

        if self.model_config.uses_mrope:
            position_ids = positions[:, None]
        else:
            position_ids = positions[None, ...]

        outputs = self.model(
            input_ids=input_ids,
            inputs_embeds=inputs_embeds,
            use_cache=False,
            position_ids=position_ids,
            attention_instances=self.attention_instances,
            return_dict=False,
            **self._output_aux_hidden_states_kwargs,
            **kwargs,
        )

        # We must remove the batch dimension from these outputs
        hidden_states = outputs[0][0, ...]
        if self._output_aux_hidden_states_kwargs:
            aux_hidden_states = [x[0][0, ...] for x in outputs[1:]]

        if not self.pp_group.is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})

        if self._output_aux_hidden_states_kwargs and len(aux_hidden_states) > 0:
            return hidden_states, aux_hidden_states
        return hidden_states

    def load_weights(
        self,
        weights: Iterable[tuple[str, torch.Tensor]],
    ) -> set[str]:
        """Load `weights` with an `AutoWeightsLoader`.

        The loader is configured with the skip/ignore lists stored on this
        instance, and weight names are mapped with `self.hf_to_vllm_mapper`.

        Args:
            weights: Pairs of weight name and tensor.

        Returns:
            The value returned by `AutoWeightsLoader.load_weights`.
        """
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=self.skip_prefixes,
            skip_substrs=self.skip_substrs,
            ignore_unexpected_prefixes=self.ignore_unexpected_prefixes,
            ignore_unexpected_suffixes=self.ignore_unexpected_suffixes,
        )
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

    @staticmethod
    def check_version(min_version: str, feature: str):
        """Check that the installed Transformers version is recent enough.

        Args:
            min_version: Minimum required Transformers version.
            feature: Name of the feature, used in the error message.

        Raises:
            ImportError: If the installed version is older than `min_version`.
        """
        installed = Version(transformers.__version__)
        required = Version(min_version)
        if installed < required:
            raise ImportError(
                f"Transformers modeling backend requires transformers>={required} "
                f"for {feature}, but got {installed}"
            )

    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
        """Enable recording of aux hidden states for Eagle3.

        For every entry of `layers` an `OutputRecorder` is registered on the model
        and the matching `output_aux_hidden_state_*` kwarg is enabled for
        `forward`.

        Args:
            layers: Layer indices. Each is looked up as `layer - 1` in the layer
                names collected by `recursive_replace`.

        Raises:
            ImportError: If the installed Transformers version is older than
                5.2.0.
        """
        self.check_version("5.2.0", "Eagle3 support")
        from transformers.utils.output_capturing import (
            OutputRecorder,
            maybe_install_capturing_hooks,
        )

        # The default value in PreTrainedModel is None
        if self.model._can_record_outputs is None:
            self.model._can_record_outputs = {}

        target_class = self._target_class
        for layer in layers:
            # layer - 1 because we want the input to the layer
            layer_name = self._layer_names[layer - 1]
            layer_key = f"aux_hidden_state_{layer}"
            aux_hidden_state_i = OutputRecorder(target_class, layer_name=layer_name)
            self.model._can_record_outputs[layer_key] = aux_hidden_state_i
            self._output_aux_hidden_states_kwargs[f"output_{layer_key}"] = True
        # Ensure that the capture hooks are installed before dynamo traces the model
        maybe_install_capturing_hooks(self.model)

    def get_eagle3_default_aux_hidden_state_layers(self) -> tuple[int, ...]:
        """Return the default Eagle3 aux hidden state layers.

        These are layer 2, the middle layer and the third to last layer, derived
        from `num_hidden_layers` of the text config.
        """
        num_layers = self.text_config.num_hidden_layers
        return (2, num_layers // 2, num_layers - 3)