diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py
index ab52d544c..9c55a42a4 100644
--- a/vllm/compilation/decorators.py
+++ b/vllm/compilation/decorators.py
@@ -205,6 +205,8 @@ def support_torch_compile(
                 if v.annotation in [
                     torch.Tensor,
                     torch.Tensor | None,
+                    torch.FloatTensor,
+                    torch.FloatTensor | None,
                     IntermediateTensors,
                     IntermediateTensors | None,
                 ]:
@@ -346,7 +348,7 @@ def _support_torch_compile(
 
     def __init__(
         self: _T,
-        *,
+        *args,
         vllm_config: VllmConfig | None = None,
         prefix: str = "",
         **kwargs: Any,
@@ -357,11 +359,24 @@ def _support_torch_compile(
         # NOTE: to support multimodal models (such as encoder),
         # we may not have vllm_config so we may need to patch it
         sig = inspect.signature(old_init)
+        # Check positional args against old_init's annotations (skipping `self`)
+        annotations = [p.annotation for p in list(sig.parameters.values())[1:]]
+        for arg, annotation in zip(args, annotations):
+            if annotation is inspect.Parameter.empty:
+                continue
+            if not isinstance(arg, annotation):
+                init = f"'{type(self).__name__}.__init__'"
+                arg_type = f"'{type(arg).__name__}'"
+                raise TypeError(
+                    f"{init} received a positional argument of type {arg_type}, "
+                    "but no parameter of that type was found in the method signature. "
+                    f"Please either annotate {init} or pass it as a keyword argument."
+                )
         if "vllm_config" in sig.parameters:
             kwargs["vllm_config"] = vllm_config
         if "prefix" in sig.parameters:
             kwargs["prefix"] = prefix
-        old_init(self, **kwargs)
+        old_init(self, *args, **kwargs)
 
         self.vllm_config = vllm_config
         self.compilation_config = self.vllm_config.compilation_config
diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py
index 916c5a002..6a089fdfa 100644
--- a/vllm/config/compilation.py
+++ b/vllm/config/compilation.py
@@ -495,9 +495,10 @@ class CompilationConfig:
     If empty list [], no ops are excluded (suitable for full cudagraphs)."""
     compile_mm_encoder: bool = False
     """Whether or not to compile the multimodal encoder.
-    Currently, this only works for `Qwen2_5_vl` and `mLLaMa4` models
-    on selected platforms. Disabled by default until more models
-    are supported/tested to work."""
+    Currently, this only works for `Qwen2_5_vl` and `mLLaMa4` models on selected
+    platforms. It may also work for models loaded with the Transformers modeling backend
+    if the encoder is compilable. Disabled by default until more models are
+    supported/tested to work."""
 
     # Vision encoder CUDA graph
     cudagraph_mm_encoder: bool = False
diff --git a/vllm/model_executor/models/transformers/__init__.py b/vllm/model_executor/models/transformers/__init__.py
index 93cd8ff50..cb224e5cb 100644
--- a/vllm/model_executor/models/transformers/__init__.py
+++ b/vllm/model_executor/models/transformers/__init__.py
@@ -16,13 +16,11 @@
 # limitations under the License.
 """Wrapper around `transformers` models"""
 
-from vllm.compilation.decorators import support_torch_compile
 from vllm.model_executor.models.transformers.base import Base
 from vllm.model_executor.models.transformers.causal import CausalMixin
 from vllm.model_executor.models.transformers.legacy import LegacyMixin
 from vllm.model_executor.models.transformers.moe import MoEMixin
 from vllm.model_executor.models.transformers.multimodal import (
-    DYNAMIC_ARG_DIMS,
     MultiModalDummyInputsBuilder,
     MultiModalMixin,
     MultiModalProcessingInfo,
@@ -32,16 +30,13 @@ from vllm.model_executor.models.transformers.pooling import (
     EmbeddingMixin,
     SequenceClassificationMixin,
 )
-from vllm.model_executor.models.transformers.utils import can_enable_torch_compile
 from vllm.multimodal import MULTIMODAL_REGISTRY
 
 
 # Text only models
-@support_torch_compile(enable_if=can_enable_torch_compile)
 class TransformersForCausalLM(CausalMixin, Base): ...
 
 
-@support_torch_compile(enable_if=can_enable_torch_compile)
 class TransformersMoEForCausalLM(MoEMixin, CausalMixin, Base): ...
 
 
@@ -51,9 +46,6 @@ class TransformersMoEForCausalLM(MoEMixin, CausalMixin, Base): ...
     info=MultiModalProcessingInfo,
     dummy_inputs=MultiModalDummyInputsBuilder,
 )
-@support_torch_compile(
-    dynamic_arg_dims=DYNAMIC_ARG_DIMS, enable_if=can_enable_torch_compile
-)
 class TransformersMultiModalForCausalLM(MultiModalMixin, CausalMixin, Base): ...
 
 
@@ -62,20 +54,15 @@ class TransformersMultiModalForCausalLM(MultiModalMixin, CausalMixin, Base): ...
     info=MultiModalProcessingInfo,
     dummy_inputs=MultiModalDummyInputsBuilder,
 )
-@support_torch_compile(
-    dynamic_arg_dims=DYNAMIC_ARG_DIMS, enable_if=can_enable_torch_compile
-)
 class TransformersMultiModalMoEForCausalLM(
     MoEMixin, MultiModalMixin, CausalMixin, Base
 ): ...
 
 
 # Embedding models
-@support_torch_compile(enable_if=can_enable_torch_compile)
 class TransformersEmbeddingModel(EmbeddingMixin, LegacyMixin, Base): ...
 
 
-@support_torch_compile(enable_if=can_enable_torch_compile)
 class TransformersMoEEmbeddingModel(EmbeddingMixin, MoEMixin, Base): ...
 
 
@@ -84,20 +71,15 @@ class TransformersMoEEmbeddingModel(EmbeddingMixin, MoEMixin, Base): ...
     info=MultiModalProcessingInfo,
     dummy_inputs=MultiModalDummyInputsBuilder,
 )
-@support_torch_compile(
-    dynamic_arg_dims=DYNAMIC_ARG_DIMS, enable_if=can_enable_torch_compile
-)
 class TransformersMultiModalEmbeddingModel(EmbeddingMixin, MultiModalMixin, Base): ...
 
 
 # Sequence classification models
-@support_torch_compile(enable_if=can_enable_torch_compile)
 class TransformersForSequenceClassification(
     SequenceClassificationMixin, LegacyMixin, Base
 ): ...
 
 
-@support_torch_compile(enable_if=can_enable_torch_compile)
 class TransformersMoEForSequenceClassification(
     SequenceClassificationMixin, MoEMixin, Base
 ): ...
@@ -108,9 +90,6 @@ class TransformersMoEForSequenceClassification(
     info=MultiModalProcessingInfo,
     dummy_inputs=MultiModalDummyInputsBuilder,
 )
-@support_torch_compile(
-    dynamic_arg_dims=DYNAMIC_ARG_DIMS, enable_if=can_enable_torch_compile
-)
 class TransformersMultiModalForSequenceClassification(
     SequenceClassificationMixin, MultiModalMixin, Base
 ): ...
diff --git a/vllm/model_executor/models/transformers/base.py b/vllm/model_executor/models/transformers/base.py
index d32bfe6ca..8b3ef56c8 100644
--- a/vllm/model_executor/models/transformers/base.py
+++ b/vllm/model_executor/models/transformers/base.py
@@ -16,6 +16,7 @@
 # limitations under the License.
 """Transformers modeling backend base class."""
 
+import sys
 from collections.abc import Callable, Iterable
 from itertools import chain
 from operator import attrgetter
@@ -29,6 +30,7 @@ from torch import nn
 from transformers import AutoModel
 from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
 
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config.utils import getattr_iter
 from vllm.distributed import get_pp_group, get_tp_group
 from vllm.distributed.utils import get_pp_indices
@@ -47,6 +49,7 @@ from vllm.model_executor.models.interfaces import (
 )
 from vllm.model_executor.models.interfaces_base import VllmModel
 from vllm.model_executor.models.transformers.utils import (
+    can_enable_torch_compile,
     get_feature_request_tip,
     init_on_device_without_buffers,
     log_replacement,
@@ -117,6 +120,7 @@ class Base(
         self.config = vllm_config.model_config.hf_config
         self.text_config = self.config.get_text_config()
         self.cache_config = vllm_config.cache_config
+        self.compilation_config = vllm_config.compilation_config
         self.device_config = vllm_config.device_config
         self.model_config = vllm_config.model_config
         self.parallel_config = vllm_config.parallel_config
@@ -155,14 +159,16 @@ class Base(
             if "gptq" in quant_method_name:
                 self.ignore_unexpected_suffixes.append(".bias")
 
-        # Patch config and init on "meta" to delay allocating GPU tensors
         self._patch_config()
+        from_config_kwargs = dict(
+            config=self.config,
+            dtype=self.model_config.dtype,
+            trust_remote_code=self.model_config.trust_remote_code,
+        )
+        self._decorate_for_torch_compile(**from_config_kwargs)
+        # Init on "meta" to delay allocating GPU tensors
         with init_on_device_without_buffers("meta"):
-            self.model: PreTrainedModel = AutoModel.from_config(
-                self.config,
-                dtype=self.model_config.dtype,
-                trust_remote_code=self.model_config.trust_remote_code,
-            )
+            self.model: PreTrainedModel = AutoModel.from_config(**from_config_kwargs)
 
         # Create weight name to module qualname mapper
         self._create_hf_to_vllm_mapper()
@@ -218,6 +224,82 @@ class Base(
             if sub_config.dtype != (dtype := self.config.dtype):
                 sub_config.dtype = dtype
 
+    def _get_decoder_cls(self, **kwargs: dict) -> type[PreTrainedModel]:
+        """
+        Get the decoder class from the model.
+
+        Args:
+            kwargs: The kwargs to create the model.
+
+        Returns:
+            The decoder class.
+        """
+        with torch.device("meta"):
+            model: PreTrainedModel = AutoModel.from_config(**kwargs)
+        decoder_cls = type(model.get_decoder())
+        logger.debug("Identified decoder class as: %s", decoder_cls)
+        del model
+        return decoder_cls
+
+    def _decorate_cls_for_torch_compile(
+        self,
+        cls: type[PreTrainedModel],
+        dynamic_arg_dims: dict[str, int] | None,
+        enable_if: Callable[["VllmConfig"], bool],
+        is_encoder: bool,
+    ):
+        """
+        Decorate `cls` to indicate to vLLM that it supports torch compile.
+
+        Args:
+            cls: The PreTrainedModel class to decorate.
+            dynamic_arg_dims: A mapping from argument name to the dynamic dimensions
+                of the argument. If None, default dynamic arg dims will be used. See
+                [`support_torch_compile`][vllm.compilation.decorators.support_torch_compile]
+                for more details.
+            enable_if: A function which takes in the vLLM config and returns whether
+                torch compile should be enabled for this class.
+            is_encoder: Whether the class being decorated is an encoder.
+        """
+        logger.debug(
+            "Decorating `%s` as %s for torch compile with dynamic_arg_dims of %s",
+            cls.__name__,
+            "encoder" if is_encoder else "decoder",
+            dynamic_arg_dims,
+        )
+
+        @support_torch_compile(
+            dynamic_arg_dims=dynamic_arg_dims,
+            enable_if=enable_if,
+            is_encoder=is_encoder,
+        )
+        class SupportTorchCompileWrapper(cls): ...
+
+        # Patch the class in its module
+        module = sys.modules[cls.__module__]
+        setattr(module, cls.__name__, SupportTorchCompileWrapper)
+
+    def _decorate_for_torch_compile(self, **kwargs: dict):
+        """
+        Decorate the model's decoder class to indicate to vLLM that it supports torch
+        compile if `can_enable_torch_compile` is True.
+
+        Args:
+            kwargs: The kwargs to create the model, which are needed to get the decoder
+                class.
+        """
+        self._decorate_cls_for_torch_compile(
+            cls=self._get_decoder_cls(**kwargs),
+            # Applied to a PreTrainedModel so the batch dimension will exist
+            dynamic_arg_dims=dict[str, int](
+                input_ids=1,  # shape: [1, seq_len]
+                inputs_embeds=1,  # shape: [1, seq_len, hidden_size]
+                position_ids=-1,  # shape: [1, seq_len] or [3, 1, seq_len] for mrope
+            ),
+            enable_if=can_enable_torch_compile,
+            is_encoder=False,
+        )
+
     def _create_hf_to_vllm_mapper(self):
         """
         Create a WeightsMapper to map checkpoint weight names to module qualnames.
@@ -553,11 +635,6 @@ class Base(
             input_ids = None
             inputs_embeds = intermediate_tensors["hidden_states"]
 
-        if input_ids is not None:
-            input_ids = input_ids[None, ...]
-        if inputs_embeds is not None:
-            inputs_embeds = inputs_embeds[None, ...]
-
         # If the model scales embeddings inside the input embedding layer we must
         # ensure they are scaled here since VocabParallelEmbedding will not do it
         if (
@@ -568,22 +645,29 @@ class Base(
             inputs_embeds = self.embed_input_ids(input_ids)
             input_ids = None
 
-        if self.model_config.uses_mrope:
-            position_ids = positions[:, None]
-        else:
-            position_ids = positions[None, ...]
+        # Add batch dimension before entering Transformers model
+        if input_ids is not None and input_ids.ndim == 1:
+            # [seq_len] -> [1, seq_len]
+            input_ids = input_ids[None, ...]
+        if inputs_embeds is not None and inputs_embeds.ndim == 2:
+            # [seq_len, hidden_size] -> [1, seq_len, hidden_size]
+            inputs_embeds = inputs_embeds[None, ...]
+        if positions.ndim == 1:
+            # [seq_len] -> [1, seq_len]
+            positions = positions[None, ...]
 
         outputs = self.model(
             input_ids=input_ids,
             inputs_embeds=inputs_embeds,
             use_cache=False,
-            position_ids=position_ids,
+            position_ids=positions,
             attention_instances=self.attention_instances,
             return_dict=False,
             **self._output_aux_hidden_states_kwargs,
             **kwargs,
         )
-        # We must remove the batch dimension from these outputs
+
+        # Remove batch dimension after exiting Transformers model
         hidden_states = outputs[0][0, ...]
         if self._output_aux_hidden_states_kwargs:
             aux_hidden_states = [x[0][0, ...] for x in outputs[1:]]
diff --git a/vllm/model_executor/models/transformers/multimodal.py b/vllm/model_executor/models/transformers/multimodal.py
index ddcd91f61..ab6ba91d2 100644
--- a/vllm/model_executor/models/transformers/multimodal.py
+++ b/vllm/model_executor/models/transformers/multimodal.py
@@ -20,7 +20,9 @@ from collections.abc import Mapping
 from typing import TYPE_CHECKING
 
 import torch
+from transformers import AutoModel
 
+from vllm.compilation.decorators import should_torch_compile_mm_encoder
 from vllm.config.utils import getattr_iter
 from vllm.inputs import MultiModalDataDict, MultiModalInput, mm_input
 from vllm.logger import init_logger
@@ -46,19 +48,11 @@ from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors
 
 if TYPE_CHECKING:
-    from transformers import BatchFeature
+    from transformers import BatchFeature, PreTrainedModel
 
     from vllm.config import VllmConfig
     from vllm.config.multimodal import BaseDummyOptions
 
-DYNAMIC_ARG_DIMS = {
-    "input_ids": 0,
-    # set `positions` to last dim to support Qwen-mrope
-    "positions": -1,
-    "intermediate_tensors": 0,
-    "inputs_embeds": 0,
-}
-
 logger = init_logger(__name__)
 
 
@@ -274,6 +268,66 @@ class MultiModalMixin(SupportsMultiModal, SupportsMRoPE):
         # Skip SupportsMRoPE.__init__ and call the next class in MRO
         super(SupportsMRoPE, self).__init__(vllm_config=vllm_config, prefix=prefix)
 
+    def _get_encoder_cls(
+        self, modality: str = "image", **kwargs: dict
+    ) -> type["PreTrainedModel"]:
+        """
+        Get the encoder class from the model.
+
+        Args:
+            kwargs: The kwargs to create the model.
+
+        Returns:
+            The encoder class.
+        """
+        with torch.device("meta"):
+            model: PreTrainedModel = AutoModel.from_config(**kwargs)
+        encoder_cls = type(model.get_encoder(modality=modality))
+        logger.debug("Identified encoder class as: %s", encoder_cls)
+        if type(model) is encoder_cls:
+            raise ValueError(
+                "Unable to infer vision encoder class from the model. "
+                "You must either: update the model so that "
+                "https://huggingface.co/docs/transformers/en/main_classes/model#transformers.PreTrainedModel.get_encoder"
+                " can detect the vision encoder correctly, or remove "
+                "'compile_mm_encoder'."
+            )
+        del model
+        return encoder_cls
+
+    def _decorate_for_torch_compile(self, **kwargs: dict):
+        """
+        Decorate the model's decoder and encoder classes to indicate to vLLM that they
+        support torch compile if `can_enable_torch_compile` and
+        `should_torch_compile_mm_encoder` are True respectively.
+
+        Args:
+            kwargs: The kwargs to create the model, which are needed to get the decoder
+                and encoder classes.
+        """
+        super()._decorate_for_torch_compile(**kwargs)
+        # Decorate the vision encoder model class to support torch compile if needed
+        if self.compilation_config.compile_mm_encoder:
+            self.check_version("5.0.0", "multimodal encoder compilation support")
+            logger.warning_once(
+                "Multimodal encoder compilation with the Transformers modeling backend "
+                "is an experimental feature. It relies on:\n"
+                "- The vision encoder being torch compilable.\n"
+                "- All vision encoder tensor inputs must be type hinted as either "
+                "`torch.Tensor` or `torch.FloatTensor`.\n"
+                "- The 0-th dimension of all tensor inputs to the vision encoder being "
+                "the dynamic dimension (i.e., sequence length or number of patches).\n"
+                "Please report any issues you encounter to help us improve it."
+            )
+            self._decorate_cls_for_torch_compile(
+                cls=self._get_encoder_cls(**kwargs),
+                # TODO: properly infer dynamic_arg_dims based on the encoder's forward
+                # method signature. Currently we assume dim 0 for all tensor inputs.
+                dynamic_arg_dims=None,
+                enable_if=should_torch_compile_mm_encoder,
+                is_encoder=True,
+            )
+
     def forward(
         self,
         input_ids: torch.Tensor | None,
@@ -285,6 +339,10 @@ class MultiModalMixin(SupportsMultiModal, SupportsMRoPE):
         # Gemma3 and PaliGemma needs `token_type_ids` to work correctly
         # Other models will not have `token_type_ids` in kwargs
         kwargs = {k: v for k, v in kwargs.items() if k == "token_type_ids"}
+        # Positions shape handling for MRoPE models
+        if self.model_config.uses_mrope:
+            # [3, seq_len] -> [3, 1, seq_len]
+            positions = positions[:, None]
         model_output = super().forward(
             input_ids, positions, intermediate_tensors, inputs_embeds, **kwargs
         )