Don't compile vision encoder for Transformers backend (#30518)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -205,6 +205,8 @@ def support_torch_compile(
|
||||
if v.annotation in [
|
||||
torch.Tensor,
|
||||
torch.Tensor | None,
|
||||
torch.FloatTensor,
|
||||
torch.FloatTensor | None,
|
||||
IntermediateTensors,
|
||||
IntermediateTensors | None,
|
||||
]:
|
||||
@@ -346,7 +348,7 @@ def _support_torch_compile(
|
||||
|
||||
def __init__(
|
||||
self: _T,
|
||||
*,
|
||||
*args,
|
||||
vllm_config: VllmConfig | None = None,
|
||||
prefix: str = "",
|
||||
**kwargs: Any,
|
||||
@@ -357,11 +359,24 @@ def _support_torch_compile(
|
||||
# NOTE: to support multimodal models (such as encoder),
|
||||
# we may not have vllm_config so we may need to patch it
|
||||
sig = inspect.signature(old_init)
|
||||
# Check that any positional arguments match the old_init method signature
|
||||
annotations = [p.annotation for p in sig.parameters.values()]
|
||||
for arg, annotation in zip(args, annotations):
|
||||
if annotation is inspect._empty:
|
||||
continue
|
||||
if not isinstance(arg, annotation):
|
||||
init = f"'{type(self).__name__}.__init__'"
|
||||
arg_type = f"'{type(arg).__name__}'"
|
||||
raise TypeError(
|
||||
f"{init} received a positional argument of type {arg_type}, "
|
||||
"but no parameter of that type was found in the method signature. "
|
||||
f"Please either annotate {init} or pass it as a keyword argument."
|
||||
)
|
||||
if "vllm_config" in sig.parameters:
|
||||
kwargs["vllm_config"] = vllm_config
|
||||
if "prefix" in sig.parameters:
|
||||
kwargs["prefix"] = prefix
|
||||
old_init(self, **kwargs)
|
||||
old_init(self, *args, **kwargs)
|
||||
|
||||
self.vllm_config = vllm_config
|
||||
self.compilation_config = self.vllm_config.compilation_config
|
||||
|
||||
@@ -495,9 +495,10 @@ class CompilationConfig:
|
||||
If empty list [], no ops are excluded (suitable for full cudagraphs)."""
|
||||
compile_mm_encoder: bool = False
|
||||
"""Whether or not to compile the multimodal encoder.
|
||||
Currently, this only works for `Qwen2_5_vl` and `mLLaMa4` models
|
||||
on selected platforms. Disabled by default until more models
|
||||
are supported/tested to work."""
|
||||
Currently, this only works for `Qwen2_5_vl` and `mLLaMa4` models on selected
|
||||
platforms. It may also work for models loaded with the Transformers modeling backend
|
||||
if the encoder is compilable. Disabled by default until more models are
|
||||
supported/tested to work."""
|
||||
|
||||
# Vision encoder CUDA graph
|
||||
cudagraph_mm_encoder: bool = False
|
||||
|
||||
@@ -16,13 +16,11 @@
|
||||
# limitations under the License.
|
||||
"""Wrapper around `transformers` models"""
|
||||
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.model_executor.models.transformers.base import Base
|
||||
from vllm.model_executor.models.transformers.causal import CausalMixin
|
||||
from vllm.model_executor.models.transformers.legacy import LegacyMixin
|
||||
from vllm.model_executor.models.transformers.moe import MoEMixin
|
||||
from vllm.model_executor.models.transformers.multimodal import (
|
||||
DYNAMIC_ARG_DIMS,
|
||||
MultiModalDummyInputsBuilder,
|
||||
MultiModalMixin,
|
||||
MultiModalProcessingInfo,
|
||||
@@ -32,16 +30,13 @@ from vllm.model_executor.models.transformers.pooling import (
|
||||
EmbeddingMixin,
|
||||
SequenceClassificationMixin,
|
||||
)
|
||||
from vllm.model_executor.models.transformers.utils import can_enable_torch_compile
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
|
||||
|
||||
# Text only models
# Dense (non-MoE) causal LM: CausalMixin layers LM behavior on top of Base.
@support_torch_compile(enable_if=can_enable_torch_compile)
class TransformersForCausalLM(CausalMixin, Base): ...
|
||||
|
||||
|
||||
# MoE variant of the causal LM; MoEMixin precedes CausalMixin in the MRO.
@support_torch_compile(enable_if=can_enable_torch_compile)
class TransformersMoEForCausalLM(MoEMixin, CausalMixin, Base): ...
|
||||
|
||||
|
||||
@@ -51,9 +46,6 @@ class TransformersMoEForCausalLM(MoEMixin, CausalMixin, Base): ...
|
||||
info=MultiModalProcessingInfo,
|
||||
dummy_inputs=MultiModalDummyInputsBuilder,
|
||||
)
|
||||
@support_torch_compile(
|
||||
dynamic_arg_dims=DYNAMIC_ARG_DIMS, enable_if=can_enable_torch_compile
|
||||
)
|
||||
class TransformersMultiModalForCausalLM(MultiModalMixin, CausalMixin, Base): ...
|
||||
|
||||
|
||||
@@ -62,20 +54,15 @@ class TransformersMultiModalForCausalLM(MultiModalMixin, CausalMixin, Base): ...
|
||||
info=MultiModalProcessingInfo,
|
||||
dummy_inputs=MultiModalDummyInputsBuilder,
|
||||
)
|
||||
@support_torch_compile(
|
||||
dynamic_arg_dims=DYNAMIC_ARG_DIMS, enable_if=can_enable_torch_compile
|
||||
)
|
||||
class TransformersMultiModalMoEForCausalLM(
|
||||
MoEMixin, MultiModalMixin, CausalMixin, Base
|
||||
): ...
|
||||
|
||||
|
||||
# Embedding models
# Embedding model composed from EmbeddingMixin + LegacyMixin + Base.
@support_torch_compile(enable_if=can_enable_torch_compile)
class TransformersEmbeddingModel(EmbeddingMixin, LegacyMixin, Base): ...
|
||||
|
||||
|
||||
# MoE embedding model (note: no LegacyMixin, unlike TransformersEmbeddingModel).
@support_torch_compile(enable_if=can_enable_torch_compile)
class TransformersMoEEmbeddingModel(EmbeddingMixin, MoEMixin, Base): ...
|
||||
|
||||
|
||||
@@ -84,20 +71,15 @@ class TransformersMoEEmbeddingModel(EmbeddingMixin, MoEMixin, Base): ...
|
||||
info=MultiModalProcessingInfo,
|
||||
dummy_inputs=MultiModalDummyInputsBuilder,
|
||||
)
|
||||
@support_torch_compile(
|
||||
dynamic_arg_dims=DYNAMIC_ARG_DIMS, enable_if=can_enable_torch_compile
|
||||
)
|
||||
class TransformersMultiModalEmbeddingModel(EmbeddingMixin, MultiModalMixin, Base): ...
|
||||
|
||||
|
||||
# Sequence classification models
# Sequence classifier composed from SequenceClassificationMixin + LegacyMixin + Base.
@support_torch_compile(enable_if=can_enable_torch_compile)
class TransformersForSequenceClassification(
    SequenceClassificationMixin, LegacyMixin, Base
): ...
|
||||
|
||||
|
||||
# MoE sequence classifier (no LegacyMixin, unlike the dense variant above).
@support_torch_compile(enable_if=can_enable_torch_compile)
class TransformersMoEForSequenceClassification(
    SequenceClassificationMixin, MoEMixin, Base
): ...
|
||||
@@ -108,9 +90,6 @@ class TransformersMoEForSequenceClassification(
|
||||
info=MultiModalProcessingInfo,
|
||||
dummy_inputs=MultiModalDummyInputsBuilder,
|
||||
)
|
||||
@support_torch_compile(
|
||||
dynamic_arg_dims=DYNAMIC_ARG_DIMS, enable_if=can_enable_torch_compile
|
||||
)
|
||||
class TransformersMultiModalForSequenceClassification(
|
||||
SequenceClassificationMixin, MultiModalMixin, Base
|
||||
): ...
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
# limitations under the License.
|
||||
"""Transformers modeling backend base class."""
|
||||
|
||||
import sys
|
||||
from collections.abc import Callable, Iterable
|
||||
from itertools import chain
|
||||
from operator import attrgetter
|
||||
@@ -29,6 +30,7 @@ from torch import nn
|
||||
from transformers import AutoModel
|
||||
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
|
||||
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config.utils import getattr_iter
|
||||
from vllm.distributed import get_pp_group, get_tp_group
|
||||
from vllm.distributed.utils import get_pp_indices
|
||||
@@ -47,6 +49,7 @@ from vllm.model_executor.models.interfaces import (
|
||||
)
|
||||
from vllm.model_executor.models.interfaces_base import VllmModel
|
||||
from vllm.model_executor.models.transformers.utils import (
|
||||
can_enable_torch_compile,
|
||||
get_feature_request_tip,
|
||||
init_on_device_without_buffers,
|
||||
log_replacement,
|
||||
@@ -117,6 +120,7 @@ class Base(
|
||||
self.config = vllm_config.model_config.hf_config
|
||||
self.text_config = self.config.get_text_config()
|
||||
self.cache_config = vllm_config.cache_config
|
||||
self.compilation_config = vllm_config.compilation_config
|
||||
self.device_config = vllm_config.device_config
|
||||
self.model_config = vllm_config.model_config
|
||||
self.parallel_config = vllm_config.parallel_config
|
||||
@@ -155,14 +159,16 @@ class Base(
|
||||
if "gptq" in quant_method_name:
|
||||
self.ignore_unexpected_suffixes.append(".bias")
|
||||
|
||||
# Patch config and init on "meta" to delay allocating GPU tensors
|
||||
self._patch_config()
|
||||
from_config_kwargs = dict(
|
||||
config=self.config,
|
||||
dtype=self.model_config.dtype,
|
||||
trust_remote_code=self.model_config.trust_remote_code,
|
||||
)
|
||||
self._decorate_for_torch_compile(**from_config_kwargs)
|
||||
# Init on "meta" to delay allocating GPU tensors
|
||||
with init_on_device_without_buffers("meta"):
|
||||
self.model: PreTrainedModel = AutoModel.from_config(
|
||||
self.config,
|
||||
dtype=self.model_config.dtype,
|
||||
trust_remote_code=self.model_config.trust_remote_code,
|
||||
)
|
||||
self.model: PreTrainedModel = AutoModel.from_config(**from_config_kwargs)
|
||||
|
||||
# Create weight name to module qualname mapper
|
||||
self._create_hf_to_vllm_mapper()
|
||||
@@ -218,6 +224,82 @@ class Base(
|
||||
if sub_config.dtype != (dtype := self.config.dtype):
|
||||
sub_config.dtype = dtype
|
||||
|
||||
def _get_decoder_cls(self, **kwargs: dict) -> type[PreTrainedModel]:
    """
    Get the decoder class from the model.

    Args:
        kwargs: The kwargs to create the model.

    Returns:
        The decoder class.
    """
    # Construct on the "meta" device so no real tensors are allocated; we
    # only need the module object graph to discover the decoder's class.
    with torch.device("meta"):
        model: PreTrainedModel = AutoModel.from_config(**kwargs)
    decoder_cls = type(model.get_decoder())
    logger.debug("Identified decoder class as: %s", decoder_cls)
    # Drop the throwaway meta model immediately; only the class is needed.
    del model
    return decoder_cls
|
||||
|
||||
def _decorate_cls_for_torch_compile(
    self,
    cls: type[PreTrainedModel],
    dynamic_arg_dims: dict[str, int] | None,
    enable_if: Callable[["VllmConfig"], bool],
    is_encoder: bool,
):
    """
    Decorate `cls` to indicate to vLLM that it supports torch compile.

    Args:
        cls: The PreTrainedModel class to decorate.
        dynamic_arg_dims: A mapping from argument name to the dynamic dimensions
            of the argument. If None, default dynamic arg dims will be used. See
            [`support_torch_compile`][vllm.compilation.decorators.support_torch_compile]
            for more details.
        enable_if: A function which takes in the vLLM config and returns whether
            torch compile should be enabled for this class.
        is_encoder: Whether the class being decorated is an encoder.
    """
    logger.debug(
        "Decorating `%s` as %s for torch compile with dynamic_arg_dims of %s",
        cls.__name__,
        "encoder" if is_encoder else "decoder",
        dynamic_arg_dims,
    )

    # Subclass `cls` so the decorator wraps it without mutating the original
    # class object itself.
    @support_torch_compile(
        dynamic_arg_dims=dynamic_arg_dims,
        enable_if=enable_if,
        is_encoder=is_encoder,
    )
    class SupportTorchCompileWrapper(cls): ...

    # Patch the class in its module so that subsequent lookups of `cls` by
    # name in that module resolve to the compile-enabled wrapper instead.
    module = sys.modules[cls.__module__]
    setattr(module, cls.__name__, SupportTorchCompileWrapper)
|
||||
|
||||
def _decorate_for_torch_compile(self, **kwargs: dict):
    """
    Mark the model's decoder class as torch-compilable for vLLM, gated on
    `can_enable_torch_compile`.

    Args:
        kwargs: The kwargs to create the model, which are needed to get the
            decoder class.
    """
    # The decoration is applied to a PreTrainedModel, so every input carries a
    # leading batch dimension and the dynamic (sequence) dim is shifted by one.
    decoder_dynamic_arg_dims: dict[str, int] = {
        "input_ids": 1,  # shape: [1, seq_len]
        "inputs_embeds": 1,  # shape: [1, seq_len, hidden_size]
        "position_ids": -1,  # shape: [1, seq_len] or [3, 1, seq_len] for mrope
    }
    self._decorate_cls_for_torch_compile(
        cls=self._get_decoder_cls(**kwargs),
        dynamic_arg_dims=decoder_dynamic_arg_dims,
        enable_if=can_enable_torch_compile,
        is_encoder=False,
    )
|
||||
|
||||
def _create_hf_to_vllm_mapper(self):
|
||||
"""
|
||||
Create a WeightsMapper to map checkpoint weight names to module qualnames.
|
||||
@@ -553,11 +635,6 @@ class Base(
|
||||
input_ids = None
|
||||
inputs_embeds = intermediate_tensors["hidden_states"]
|
||||
|
||||
if input_ids is not None:
|
||||
input_ids = input_ids[None, ...]
|
||||
if inputs_embeds is not None:
|
||||
inputs_embeds = inputs_embeds[None, ...]
|
||||
|
||||
# If the model scales embeddings inside the input embedding layer we must
|
||||
# ensure they are scaled here since VocabParallelEmbedding will not do it
|
||||
if (
|
||||
@@ -568,22 +645,29 @@ class Base(
|
||||
inputs_embeds = self.embed_input_ids(input_ids)
|
||||
input_ids = None
|
||||
|
||||
if self.model_config.uses_mrope:
|
||||
position_ids = positions[:, None]
|
||||
else:
|
||||
position_ids = positions[None, ...]
|
||||
# Add batch dimension before entering Transformers model
|
||||
if input_ids is not None and input_ids.ndim == 1:
|
||||
# [seq_len] -> [1, seq_len]
|
||||
input_ids = input_ids[None, ...]
|
||||
if inputs_embeds is not None and inputs_embeds.ndim == 2:
|
||||
# [seq_len, hidden_size] -> [1, seq_len, hidden_size]
|
||||
inputs_embeds = inputs_embeds[None, ...]
|
||||
if positions.ndim == 1:
|
||||
# [seq_len] -> [1, seq_len]
|
||||
positions = positions[None, ...]
|
||||
|
||||
outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=False,
|
||||
position_ids=position_ids,
|
||||
position_ids=positions,
|
||||
attention_instances=self.attention_instances,
|
||||
return_dict=False,
|
||||
**self._output_aux_hidden_states_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
# We must remove the batch dimension from these outputs
|
||||
|
||||
# Remove batch dimension after exiting Transformers model
|
||||
hidden_states = outputs[0][0, ...]
|
||||
if self._output_aux_hidden_states_kwargs:
|
||||
aux_hidden_states = [x[0][0, ...] for x in outputs[1:]]
|
||||
|
||||
@@ -20,7 +20,9 @@ from collections.abc import Mapping
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
from transformers import AutoModel
|
||||
|
||||
from vllm.compilation.decorators import should_torch_compile_mm_encoder
|
||||
from vllm.config.utils import getattr_iter
|
||||
from vllm.inputs import MultiModalDataDict, MultiModalInput, mm_input
|
||||
from vllm.logger import init_logger
|
||||
@@ -46,19 +48,11 @@ from vllm.platforms import current_platform
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import BatchFeature
|
||||
from transformers import BatchFeature, PreTrainedModel
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.multimodal import BaseDummyOptions
|
||||
|
||||
DYNAMIC_ARG_DIMS = {
|
||||
"input_ids": 0,
|
||||
# set `positions` to last dim to support Qwen-mrope
|
||||
"positions": -1,
|
||||
"intermediate_tensors": 0,
|
||||
"inputs_embeds": 0,
|
||||
}
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@@ -274,6 +268,66 @@ class MultiModalMixin(SupportsMultiModal, SupportsMRoPE):
|
||||
# Skip SupportsMRoPE.__init__ and call the next class in MRO
|
||||
super(SupportsMRoPE, self).__init__(vllm_config=vllm_config, prefix=prefix)
|
||||
|
||||
def _get_encoder_cls(
    self, modality: str = "image", **kwargs: dict
) -> type["PreTrainedModel"]:
    """
    Get the encoder class from the model.

    Args:
        modality: The modality whose encoder should be located (e.g. "image").
        kwargs: The kwargs to create the model.

    Returns:
        The encoder class.

    Raises:
        ValueError: If no distinct encoder submodule could be identified, i.e.
            `get_encoder` returned an instance of the model's own class.
    """
    # Construct on "meta" so no real tensors are allocated; we only need the
    # module object graph to discover the encoder's class.
    with torch.device("meta"):
        model: PreTrainedModel = AutoModel.from_config(**kwargs)
    encoder_cls = type(model.get_encoder(modality=modality))
    logger.debug("Identified encoder class as: %s", encoder_cls)
    # If the "encoder" class is the model's own class, no separate encoder was
    # found (presumably `get_encoder` fell back to the model itself — see the
    # linked Transformers docs); compiling the whole model here would be wrong.
    if type(model) is encoder_cls:
        raise ValueError(
            "Unable to infer vision encoder class from the model. "
            "You must either: update the model so that "
            "https://huggingface.co/docs/transformers/en/main_classes/model#transformers.PreTrainedModel.get_encoder"
            " can detect the vision encoder correctly, or remove "
            "'compile_mm_encoder'."
        )
    del model
    return encoder_cls
|
||||
|
||||
def _decorate_for_torch_compile(self, **kwargs: dict):
    """
    Decorate the model's decoder and encoder classes to indicate to vLLM that they
    support torch compile if `can_enable_torch_compile` and
    `should_torch_compile_mm_encoder` are True respectively.

    Args:
        kwargs: The kwargs to create the model, which are needed to get the decoder
            and encoder classes.
    """
    # Always handle the decoder first via the base-class implementation.
    super()._decorate_for_torch_compile(**kwargs)
    # Decorate the vision encoder model class to support torch compile if needed
    if self.compilation_config.compile_mm_encoder:
        # Gate on a minimum Transformers version for this feature.
        self.check_version("5.0.0", "multimodal encoder compilation support")
        logger.warning_once(
            "Multimodal encoder compilation with the Transformers modeling backend "
            "is an experimental feature. It relies on:\n"
            "- The vision encoder being torch compilable.\n"
            "- All vision encoder tensor inputs must be type hinted as either "
            "`torch.Tensor` or `torch.FloatTensor`.\n"
            "- The 0-th dimension of all tensor inputs to the vision encoder being "
            "the dynamic dimension (i.e., sequence length or number of patches).\n"
            "Please report any issues you encounter to help us improve it."
        )
        self._decorate_cls_for_torch_compile(
            cls=self._get_encoder_cls(**kwargs),
            # TODO: properly infer dynamic_arg_dims based on the encoder's forward
            # method signature. Currently we assume dim 0 for all tensor inputs.
            dynamic_arg_dims=None,
            enable_if=should_torch_compile_mm_encoder,
            is_encoder=True,
        )
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor | None,
|
||||
@@ -285,6 +339,10 @@ class MultiModalMixin(SupportsMultiModal, SupportsMRoPE):
|
||||
# Gemma3 and PaliGemma needs `token_type_ids` to work correctly
|
||||
# Other models will not have `token_type_ids` in kwargs
|
||||
kwargs = {k: v for k, v in kwargs.items() if k == "token_type_ids"}
|
||||
# Positions shape handling for MRoPE models
|
||||
if self.model_config.uses_mrope:
|
||||
# [3, seq_len] -> [3, 1, seq_len]
|
||||
positions = positions[:, None]
|
||||
model_output = super().forward(
|
||||
input_ids, positions, intermediate_tensors, inputs_embeds, **kwargs
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user