[Model] Use context managers for encoder- and LM-only mode (#32605)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Cyrus Leung
2026-01-20 11:43:38 +08:00
committed by GitHub
parent 6c01ffb897
commit 4753f3bf69
21 changed files with 290 additions and 353 deletions

View File

@@ -529,64 +529,3 @@ def seq_cls_model_loader(model, weights: Iterable[tuple[str, torch.Tensor]]):
method = getattr(hf_config, "method", getattr(text_config, "method", None))
assert method in SEQ_CLS_LOAD_METHODS, f"method {method} not supported"
return SEQ_CLS_LOAD_METHODS[method](model, weights)
def as_mm_encoder_only_model(cls: _T) -> _T:
"""
Subclass an existing vLLM vl model to support mm encoder only for
EPD encoder instances.
"""
if not hasattr(cls, "embed_multimodal"):
# Submodel case: return the original class.
return cls
if not hasattr(cls, "get_language_model_spec"):
raise TypeError(f"{cls} need to implement `get_language_model_spec` method.")
lm_model_cls, lm_attr = cls.get_language_model_spec()
if lm_model_cls is None or lm_attr is None:
raise TypeError(
f"{cls}.get_language_model_spec() must return (lm_model_cls, lm_attr)"
)
class DummyLM(nn.Module):
def __init__(self, *args, **kwargs):
self.make_empty_intermediate_tensors = None
class ModelForMMEncoderOnly(cls):
def __init__(
self,
*,
vllm_config: "VllmConfig",
prefix: str = "",
**kwargs: Any,
) -> None:
self.is_mm_encoder_only_model = True
origin_init = lm_model_cls.__init__
try:
lm_model_cls.__init__ = DummyLM.__init__
super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)
if hasattr(self, lm_attr):
delattr(self, lm_attr)
finally:
lm_model_cls.__init__ = origin_init
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
from .utils import AutoWeightsLoader
origin_init_ = AutoWeightsLoader.__init__
def _new_init_(self, *args, **kwargs):
origin_init_(self, *args, **kwargs)
self.skip_prefixes = (self.skip_prefixes or []) + [f"{lm_attr}."]
try:
AutoWeightsLoader.__init__ = _new_init_
result = super().load_weights(weights)
finally:
AutoWeightsLoader.__init__ = origin_init_
return result
return ModelForMMEncoderOnly # type: ignore
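For orientation before the per-model diffs below: a minimal sketch of the pattern that replaces this wrapper, using a hypothetical MyVLModel and hypothetical build_* helpers (the real context managers are added to the interfaces module in the next section). Each sub-model is constructed inside a context manager on SupportsMultiModal, so a separate encoder-only subclass is no longer needed:

# Sketch only; MyVLModel and the build_* helpers are hypothetical.
import torch.nn as nn

from vllm.config import VllmConfig
from vllm.model_executor.models.interfaces import SupportsMultiModal


class MyVLModel(nn.Module, SupportsMultiModal):
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        super().__init__()

        # Children assigned inside this block are recorded as tower components;
        # if the "image" limit is 0, they become TowerMissingLayer placeholders.
        with self._mark_tower_model(vllm_config, "image"):
            self.vision_tower = build_vision_tower(vllm_config)  # hypothetical
            self.multi_modal_projector = build_projector(vllm_config)  # hypothetical

        # Children assigned inside this block are recorded as the language model;
        # in mm-encoder-only mode they become LMMissingLayer placeholders.
        with self._mark_language_model(vllm_config):
            self.language_model = build_language_model(vllm_config)  # hypothetical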

View File

@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable, Iterable, Mapping, MutableSequence
from contextlib import contextmanager, nullcontext
from typing import (
TYPE_CHECKING,
ClassVar,
@@ -69,6 +70,46 @@ def _require_is_multimodal(is_multimodal: Tensor | None) -> Tensor:
return is_multimodal
class LMMissingLayer(nn.Module):
packed_modules_mapping: dict[str, list[str]] = {}
def make_empty_intermediate_tensors(self, *args, **kwargs):
raise RuntimeError("This module should not be called in MM encoder-only mode")
def __call__(self, *args, **kwargs):
raise RuntimeError("This module should not be called in MM encoder-only mode")
class TowerMissingLayer(nn.Module):
packed_modules_mapping: dict[str, list[str]] = {}
def __init__(self, modalities: set[str]) -> None:
super().__init__()
self.modalities = modalities
def __call__(self, *args, **kwargs):
raise RuntimeError(f"The following modalities are disabled: {self.modalities}")
@contextmanager
def _no_init_weights(module: nn.Module, placeholder: Callable[[], nn.Module]):
"""
Within this context, prevent weight initialization from using device memory and
replace direct child assignments to `module` with the result of `placeholder()`.
"""
def callback(module_, name, submodule):
if module_ is module:
return placeholder()
return submodule
with torch.nn.modules.module.register_module_module_registration_hook(callback): # noqa: E501,SIM117
with torch.device("meta"):
yield
@runtime_checkable
class SupportsMultiModal(Protocol):
"""The interface required for all multi-modal models."""
@@ -105,6 +146,16 @@ class SupportsMultiModal(Protocol):
Set internally by `MultiModalRegistry.register_processor`.
"""
_language_model_names: list[str] = []
"""
Set internally by `_mark_language_model`.
"""
_tower_model_names: list[str] = []
"""
Set internally by `_mark_tower_model`.
"""
@classmethod
def get_placeholder_str(cls, modality: str, i: int) -> str | None:
"""
@@ -134,7 +185,64 @@ class SupportsMultiModal(Protocol):
Returns:
torch.nn.Module: The core language model component.
"""
...
if self._language_model_names:
return getattr(self, self._language_model_names[0])
raise NotImplementedError(
f"No language model found in {type(self).__name__}! "
"You should initialize it inside `_mark_language_model`."
)
@contextmanager
def _mark_language_model(self, vllm_config: VllmConfig):
"""
Mark each child module that was assigned to this model
during this context as a language model component.
"""
mm_config = vllm_config.model_config.multimodal_config
children_names = list[str]()
def callback(module_, name, submodule):
if module_ is self:
children_names.append(name)
with torch.nn.modules.module.register_module_module_registration_hook(callback): # noqa: E501,SIM117
with (
_no_init_weights(self, LMMissingLayer)
if mm_config.mm_encoder_only
else nullcontext()
):
yield
self._language_model_names = children_names
@contextmanager
def _mark_tower_model(self, vllm_config: VllmConfig, modalities: set[str] | str):
"""
Mark each child module that was assigned to this model
during this context as a tower model component.
"""
if isinstance(modalities, str):
modalities = {modalities}
mm_config = vllm_config.model_config.multimodal_config
children_names = list[str]()
def callback(module_, name, submodule):
if module_ is self:
children_names.append(name)
with torch.nn.modules.module.register_module_module_registration_hook(callback): # noqa: E501,SIM117
with (
_no_init_weights(self, lambda: TowerMissingLayer(modalities))
if all(mm_config.get_limit_per_prompt(m) == 0 for m in modalities)
else nullcontext()
):
yield
self._tower_model_names = children_names
def get_num_mm_encoder_tokens(self, num_image_tokens: int) -> int:
"""
@@ -154,14 +262,6 @@ class SupportsMultiModal(Protocol):
"""
...
@classmethod
def get_language_model_spec(cls) -> tuple[nn.Module | None, str | None]:
"""
Return the language model spec:
(language model class, language model attr)
"""
return None, None
@overload
def embed_input_ids(self, input_ids: Tensor) -> Tensor: ...
@@ -299,10 +399,6 @@ def supports_multimodal_encoder_tp_data(model: type[object] | object) -> bool:
return getattr(model, "supports_encoder_tp_data", False)
def supports_mm_encoder_only(model: type[object] | object) -> bool:
return getattr(model, "is_mm_encoder_only_model", False)
@overload
def supports_multimodal_pruning(
model: type[object],

View File

@@ -550,8 +550,7 @@ class LlavaForConditionalGeneration(
):
config.projector_hidden_act = "gelu"
# TODO: Optionally initializes this for supporting embeddings.
if multimodal_config.get_limit_per_prompt("image"):
with self._mark_tower_model(vllm_config, "image"):
self.vision_tower = init_vision_tower_for_llava(
config,
quant_config=quant_config,
@@ -567,15 +566,13 @@ class LlavaForConditionalGeneration(
quant_config=quant_config,
prefix=maybe_prefix(prefix, "multi_modal_projector"),
)
else:
self.vision_tower = None
self.multi_modal_projector = None
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"),
)
with self._mark_language_model(vllm_config):
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"),
)
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors
@@ -631,8 +628,6 @@ class LlavaForConditionalGeneration(
self,
inputs: LlavaImagePixelInputs | PixtralHFImagePixelInputs,
) -> torch.Tensor | tuple[torch.Tensor, ...]:
assert self.vision_tower is not None
pixel_values = inputs["pixel_values"]
return self._image_pixels_to_features(self.vision_tower, pixel_values)
@@ -644,7 +639,6 @@ class LlavaForConditionalGeneration(
if image_input["type"] == "image_embeds":
return image_input["data"]
assert self.vision_tower is not None
image_features = self._process_image_pixels(image_input)
if isinstance(image_features, torch.Tensor):
@@ -656,9 +650,6 @@ class LlavaForConditionalGeneration(
image_embeds = torch.split(image_embeds, feature_sizes)
return image_embeds
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
@@ -727,11 +718,7 @@ class LlavaForConditionalGeneration(
return self.language_model.compute_logits(hidden_states)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
skip_prefixes = []
if self.vision_tower is None and self.multi_modal_projector is None:
skip_prefixes.extend(["vision_tower.", "multi_modal_projector."])
loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
loader = AutoWeightsLoader(self)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
def get_mm_mapping(self) -> MultiModelKeys:

View File

@@ -457,8 +457,7 @@ class Mistral3ForConditionalGeneration(
):
config.projector_hidden_act = "gelu"
# TODO: Optionally initializes this for supporting embeddings.
if multimodal_config.get_limit_per_prompt("image"):
with self._mark_tower_model(vllm_config, "image"):
self.vision_tower = init_vision_tower_for_llava(
config,
quant_config=quant_config,
@@ -476,15 +475,13 @@ class Mistral3ForConditionalGeneration(
quant_config=quant_config,
prefix=maybe_prefix(prefix, "multi_modal_projector"),
)
else:
self.vision_tower = None
self.multi_modal_projector = None
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"),
)
with self._mark_language_model(vllm_config):
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"),
)
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors
@@ -534,9 +531,6 @@ class Mistral3ForConditionalGeneration(
image_embeds = (image_embeds,)
return image_embeds
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
@@ -607,11 +601,7 @@ class Mistral3ForConditionalGeneration(
return self.language_model.compute_logits(hidden_states)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
skip_prefixes = []
if self.vision_tower is None and self.multi_modal_projector is None:
skip_prefixes = ["vision_tower.", "multi_modal_projector."]
loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
loader = AutoWeightsLoader(self)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
def get_mm_mapping(self) -> MultiModelKeys:

View File

@@ -70,12 +70,14 @@ from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import (
LMMissingLayer,
MixtureOfExperts,
MultiModalEmbeddings,
SupportsEagle3,
SupportsLoRA,
SupportsMultiModal,
SupportsPP,
TowerMissingLayer,
)
from .llama4 import Llama4ForCausalLM
from .utils import (
@@ -773,7 +775,8 @@ class Llama4ForConditionalGeneration(
self.config = config
self.quant_config = quant_config
self.multimodal_config = multimodal_config
if multimodal_config.get_limit_per_prompt("image"):
with self._mark_tower_model(vllm_config, "image"):
from vllm.compilation.backends import set_model_tag
with (
@@ -792,16 +795,15 @@ class Llama4ForConditionalGeneration(
quant_config=None,
prefix=maybe_prefix(prefix, "multi_modal_projector"),
)
else:
self.vision_model = None
self.multi_modal_projector = None
self.language_model = initialize_model(
vllm_config=vllm_config.with_hf_config(
config.text_config, ["LlamaForCausalLM"]
),
prefix=maybe_prefix(prefix, "language_model"),
model_class=Llama4ForCausalLM,
)
with self._mark_language_model(vllm_config):
self.language_model = initialize_model(
vllm_config=vllm_config.with_hf_config(
config.text_config, ["LlamaForCausalLM"]
),
prefix=maybe_prefix(prefix, "language_model"),
model_class=Llama4ForCausalLM,
)
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors
@@ -892,9 +894,6 @@ class Llama4ForConditionalGeneration(
for img in vision_embeddings_flat.split(patches_per_image, dim=0)
]
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def embed_multimodal(self, **kwargs) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
@@ -1024,6 +1023,10 @@ class Llama4ForConditionalGeneration(
for name, weight in weights:
renamed = self._rename_weight_for_modelopt_checkpoint(name)
attr = renamed.split(".", 1)[0]
if isinstance(getattr(self, attr), (LMMissingLayer, TowerMissingLayer)):
continue
if renamed.startswith("language_model."):
language_model_weights.append((renamed, weight))
else:
@@ -1133,10 +1136,6 @@ class Llama4ForConditionalGeneration(
weights
)
# Skip loading vision model and projector if they're not initialized.
if self.vision_model is None and self.multi_modal_projector is None:
other_weights = []
# Handle expert scale parameters
regular_weights, expert_scale_weights, updated_params_from_experts = (
self._handle_expert_scale_broadcasting(language_model_weights, params_dict)

View File

@@ -239,7 +239,7 @@ class OpenCUAForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
multimodal_config.is_multimodal_pruning_enabled()
)
if multimodal_config.get_limit_per_prompt("image"):
with self._mark_tower_model(vllm_config, "image"):
self.visual = OpenCUAVisionTransformer(
vision_config=config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
@@ -247,15 +247,14 @@ class OpenCUAForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
multimodal_config=self.multimodal_config,
prefix=maybe_prefix(prefix, "visual"),
)
else:
self.visual = None
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"),
architectures=["Qwen2ForCausalLM"],
)
with self._mark_language_model(vllm_config):
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"),
architectures=["Qwen2ForCausalLM"],
)
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors

View File

@@ -398,13 +398,14 @@ class PixtralForConditionalGeneration(
self.vision_args = VisionEncoderArgs(**vision_args)
# init MistralForCausalLM
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"),
)
with self._mark_language_model(vllm_config):
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"),
)
if multimodal_config.get_limit_per_prompt("image"):
with self._mark_tower_model(vllm_config, "image"):
self.vision_encoder = VisionTransformer(self.vision_args)
self.pre_mm_projector_norm = (
RMSNorm(self.vision_args.hidden_size, eps=1e-5)
@@ -423,11 +424,6 @@ class PixtralForConditionalGeneration(
self.vision_language_adapter = VisionLanguageAdapter(
self.vision_args, dim=config.text_config.hidden_size
)
else:
self.vision_encoder = None
self.pre_mm_projector_norm = None
self.patch_merger = None
self.vision_language_adapter = None
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors
@@ -449,10 +445,6 @@ class PixtralForConditionalGeneration(
self,
image_input: PixtralImagePixelInputs,
) -> tuple[torch.Tensor, ...]:
assert (
self.vision_encoder is not None and self.vision_language_adapter is not None
)
images = image_input["images"]
image_features = self.vision_encoder(images)
feature_sizes = [image_feature.shape[0] for image_feature in image_features]
@@ -477,9 +469,6 @@ class PixtralForConditionalGeneration(
image_embeds = torch.split(image_embeds, feature_sizes)
return image_embeds
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:

View File

@@ -822,6 +822,7 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
multimodal_config = vllm_config.model_config.multimodal_config
self.config = thinker_config
self.multimodal_config = multimodal_config
self.quant_config = quant_config
# force "use_flash_attention_2=True" to audio tower to align
# the results.
@@ -836,14 +837,10 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
"in the audio tower part."
)
if multimodal_config.get_limit_per_prompt("audio"):
with self._mark_tower_model(vllm_config, "audio"):
self.audio_tower = Qwen2_5OmniAudioEncoder(thinker_config.audio_config)
else:
self.audio_tower = None
if multimodal_config.get_limit_per_prompt(
"image"
) or multimodal_config.get_limit_per_prompt("video"):
with self._mark_tower_model(vllm_config, {"image", "video"}):
self.visual = Qwen2_5_VisionTransformer(
vision_config=thinker_config.vision_config,
norm_eps=getattr(thinker_config.text_config, "rms_norm_eps", 1e-6),
@@ -851,16 +848,14 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
prefix=maybe_prefix(prefix, "visual"),
multimodal_config=multimodal_config,
)
else:
self.visual = None
self.quant_config = quant_config
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "language_model"),
hf_config=thinker_config.text_config,
architectures=["Qwen2ForCausalLM"],
)
with self._mark_language_model(vllm_config):
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "language_model"),
hf_config=thinker_config.text_config,
architectures=["Qwen2ForCausalLM"],
)
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors
@@ -895,9 +890,6 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
)
return mm_input_by_modality
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_mrope_input_positions(
self,
input_tokens: list[int],
@@ -1175,19 +1167,8 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
return self.language_model.compute_logits(hidden_states)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
skip_prefixes = ["talker.", "token2wav."]
if self.audio_tower is None:
skip_prefixes.extend(["audio_tower."])
if self.visual is None:
skip_prefixes.extend(["visual."])
loader = AutoWeightsLoader(
self,
skip_prefixes=skip_prefixes,
)
loaded_weights = loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
return loaded_weights
loader = AutoWeightsLoader(self, skip_prefixes=["talker.", "token2wav."])
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
def get_mm_mapping(self) -> MultiModelKeys:
"""

View File

@@ -35,7 +35,7 @@ import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BatchFeature, Qwen2ForCausalLM
from transformers import BatchFeature
from transformers.models.qwen2_5_vl import Qwen2_5_VLProcessor
from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
Qwen2_5_VLConfig,
@@ -1145,9 +1145,7 @@ class Qwen2_5_VLForConditionalGeneration(
multimodal_config.is_multimodal_pruning_enabled()
)
if multimodal_config.get_limit_per_prompt(
"image"
) or multimodal_config.get_limit_per_prompt("video"):
with self._mark_tower_model(vllm_config, {"image", "video"}):
self.visual = Qwen2_5_VisionTransformer(
vision_config=config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
@@ -1155,14 +1153,13 @@ class Qwen2_5_VLForConditionalGeneration(
prefix=maybe_prefix(prefix, "visual"),
multimodal_config=multimodal_config,
)
else:
self.visual = None
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "language_model"),
architectures=["Qwen2ForCausalLM"],
)
with self._mark_language_model(vllm_config):
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "language_model"),
architectures=["Qwen2ForCausalLM"],
)
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors
@@ -1447,9 +1444,6 @@ class Qwen2_5_VLForConditionalGeneration(
)
return mm_input_by_modality
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs)
if not mm_input_by_modality:
@@ -1516,10 +1510,7 @@ class Qwen2_5_VLForConditionalGeneration(
return self.language_model.compute_logits(hidden_states)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
skip_prefixes = []
if self.visual is None:
skip_prefixes.extend(["visual."])
loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
loader = AutoWeightsLoader(self)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
def get_mm_mapping(self) -> MultiModelKeys:
@@ -1550,11 +1541,3 @@ class Qwen2_5_VLForConditionalGeneration(
vision_config = hf_config.vision_config
merge_size = vision_config.spatial_merge_size
return num_vision_tokens // merge_size**2
@classmethod
def get_language_model_spec(cls) -> tuple[nn.Module | None, str | None]:
"""
Return the language model spec:
(language model class, language model attr)
"""
return Qwen2ForCausalLM, "language_model"

View File

@@ -1233,9 +1233,7 @@ class Qwen2VLForConditionalGeneration(
self.config = config
self.multimodal_config = multimodal_config
if multimodal_config.get_limit_per_prompt(
"image"
) or multimodal_config.get_limit_per_prompt("video"):
with self._mark_tower_model(vllm_config, {"image", "video"}):
self.visual = Qwen2VisionTransformer(
config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
@@ -1243,14 +1241,13 @@ class Qwen2VLForConditionalGeneration(
multimodal_config=multimodal_config,
prefix=maybe_prefix(prefix, "visual"),
)
else:
self.visual = None
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "language_model"),
architectures=["Qwen2ForCausalLM"],
)
with self._mark_language_model(vllm_config):
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "language_model"),
architectures=["Qwen2ForCausalLM"],
)
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors
@@ -1371,9 +1368,6 @@ class Qwen2VLForConditionalGeneration(
return modalities
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
if not modalities:
@@ -1437,10 +1431,7 @@ class Qwen2VLForConditionalGeneration(
return self.language_model.compute_logits(hidden_states)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
skip_prefixes = []
if self.visual is None:
skip_prefixes.extend(["visual."])
loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
loader = AutoWeightsLoader(self)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
def get_mm_mapping(self) -> MultiModelKeys:

View File

@@ -1277,11 +1277,16 @@ class Qwen3VLForConditionalGeneration(
multimodal_config.is_multimodal_pruning_enabled()
)
if not multimodal_config.get_limit_per_prompt(
"image"
) and not multimodal_config.get_limit_per_prompt("video"):
self.visual = None
else:
self.use_deepstack = hasattr(config.vision_config, "deepstack_visual_indexes")
self.deepstack_num_level = (
len(config.vision_config.deepstack_visual_indexes)
if self.use_deepstack
else 0
)
self.visual_dim = config.vision_config.out_hidden_size
self.multiscale_dim = self.visual_dim * self.deepstack_num_level
with self._mark_tower_model(vllm_config, {"image", "video"}):
self.visual = Qwen3_VisionTransformer(
config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
@@ -1290,34 +1295,25 @@ class Qwen3VLForConditionalGeneration(
prefix=maybe_prefix(prefix, "visual"),
)
self.language_model = Qwen3LLMForCausalLM(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "language_model")
)
# register buffer for deepstack
if self.use_deepstack:
self.deepstack_input_embeds = [
torch.zeros(
vllm_config.scheduler_config.max_num_batched_tokens,
config.text_config.hidden_size,
)
for _ in range(self.deepstack_num_level)
]
with self._mark_language_model(vllm_config):
self.language_model = Qwen3LLMForCausalLM(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "language_model")
)
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors
)
self.use_deepstack = hasattr(config.vision_config, "deepstack_visual_indexes")
self.deepstack_num_level = (
len(config.vision_config.deepstack_visual_indexes)
if self.use_deepstack
else 0
)
# register buffer for deepstack
if self.use_deepstack and self.visual is not None:
self.deepstack_input_embeds = [
torch.zeros(
vllm_config.scheduler_config.max_num_batched_tokens,
config.text_config.hidden_size,
)
for _ in range(self.deepstack_num_level)
]
else:
self.deepstack_input_embeds = None
self.visual_dim = config.vision_config.out_hidden_size
self.multiscale_dim = self.visual_dim * self.deepstack_num_level
def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
self.language_model.model.aux_hidden_state_layers = layers
@@ -1893,9 +1889,6 @@ class Qwen3VLForConditionalGeneration(
return torch.from_numpy(llm_positions), mrope_position_delta
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings | None:
mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs)
if not mm_input_by_modality:
@@ -2076,10 +2069,7 @@ class Qwen3VLForConditionalGeneration(
return self.language_model.compute_logits(hidden_states)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
skip_prefixes = []
if self.visual is None:
skip_prefixes.extend(["visual."])
loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
loader = AutoWeightsLoader(self)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
def get_mm_mapping(self) -> MultiModelKeys:
@@ -2110,11 +2100,3 @@ class Qwen3VLForConditionalGeneration(
vision_config = hf_config.vision_config
merge_size = vision_config.spatial_merge_size
return num_vision_tokens // merge_size**2
@classmethod
def get_language_model_spec(cls) -> tuple[nn.Module | None, str | None]:
"""
Return the language model spec:
(language model class, language model attr)
"""
return Qwen3LLMForCausalLM, "language_model"

View File

@@ -424,11 +424,16 @@ class Qwen3VLMoeForConditionalGeneration(
multimodal_config.is_multimodal_pruning_enabled()
)
if not multimodal_config.get_limit_per_prompt(
"image"
) and not multimodal_config.get_limit_per_prompt("video"):
self.visual = None
else:
self.use_deepstack = hasattr(config.vision_config, "deepstack_visual_indexes")
self.deepstack_num_level = (
len(config.vision_config.deepstack_visual_indexes)
if self.use_deepstack
else 0
)
self.visual_dim = config.vision_config.out_hidden_size
self.multiscale_dim = self.visual_dim * self.deepstack_num_level
with self._mark_tower_model(vllm_config, {"image", "video"}):
self.visual = Qwen3_VisionTransformer(
config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
@@ -437,9 +442,21 @@ class Qwen3VLMoeForConditionalGeneration(
prefix=maybe_prefix(prefix, "visual"),
)
self.language_model = Qwen3MoeLLMForCausalLM(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "language_model")
)
# register buffer for deepstack
if self.use_deepstack:
self.deepstack_input_embeds = [
torch.zeros(
vllm_config.scheduler_config.max_num_batched_tokens,
config.text_config.hidden_size,
)
for _ in range(self.deepstack_num_level)
]
with self._mark_language_model(vllm_config):
self.language_model = Qwen3MoeLLMForCausalLM(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "language_model")
)
# Whether to include the gate_up_proj mapping is determined by
# the language model.
self.packed_modules_mapping = (
@@ -450,25 +467,5 @@ class Qwen3VLMoeForConditionalGeneration(
self.language_model.make_empty_intermediate_tensors
)
self.use_deepstack = hasattr(config.vision_config, "deepstack_visual_indexes")
self.deepstack_num_level = (
len(config.vision_config.deepstack_visual_indexes)
if self.use_deepstack
else 0
)
# register buffer for deepstack
if self.use_deepstack and self.visual is not None:
self.deepstack_input_embeds = [
torch.zeros(
vllm_config.scheduler_config.max_num_batched_tokens,
config.text_config.hidden_size,
)
for _ in range(self.deepstack_num_level)
]
else:
self.deepstack_input_embeds = None
self.visual_dim = config.vision_config.out_hidden_size
self.multiscale_dim = self.visual_dim * self.deepstack_num_level
# Set MoE hyperparameters
self.set_moe_parameters()

View File

@@ -942,7 +942,7 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
self.multimodal_config = multimodal_config
self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config.get_limit_per_prompt("image"):
with self._mark_tower_model(vllm_config, "image"):
self.vision_model = Step3VisionTransformer(
config.vision_config,
None,
@@ -967,17 +967,13 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
config.hidden_size,
bias=config.projector_bias,
)
else:
self.vision_model = None
self.vit_downsampler = None
self.vit_downsampler2 = None
self.vit_large_projector = None
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"),
)
with self._mark_language_model(vllm_config):
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"),
)
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors
@@ -1071,9 +1067,6 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
)
return merged_image_features
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def embed_multimodal(self, **kwargs) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
@@ -1133,15 +1126,5 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
return self.language_model.compute_logits(hidden_states)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
skip_prefixes = []
if self.vision_model is None and self.vit_large_projector is None:
skip_prefixes = [
"vision_model.",
"vit_downsampler.",
"vit_downsampler2.",
"vit_large_projector.",
]
loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
loaded_weights = loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
return loaded_weights
loader = AutoWeightsLoader(self)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

View File

@@ -504,7 +504,8 @@ class StepVLForConditionalGeneration(Step3VLForConditionalGeneration):
self.config = config
self.multimodal_config = multimodal_config
self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config.get_limit_per_prompt("image"):
with self._mark_tower_model(vllm_config, "image"):
self.vision_model = PerceptionEncoder(
config.vision_config,
get_act_fn(config.vision_config.hidden_act),
@@ -521,15 +522,13 @@ class StepVLForConditionalGeneration(Step3VLForConditionalGeneration):
prefix=maybe_prefix(prefix, "vit_large_projector"),
disable_tp=self.use_data_parallel,
)
else:
self.vision_model = None
self.vit_large_projector = None
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"),
)
with self._mark_language_model(vllm_config):
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"),
)
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors

View File

@@ -24,7 +24,11 @@ from vllm.model_executor.model_loader.online_quantization import (
support_quantized_model_reload_from_hp_weights,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import supports_any_eagle
from vllm.model_executor.models.interfaces import (
LMMissingLayer,
TowerMissingLayer,
supports_any_eagle,
)
from vllm.multimodal import NestedTensors
from vllm.sequence import IntermediateTensors
from vllm.utils.math_utils import cdiv
@@ -250,7 +254,7 @@ class AutoWeightsLoader:
module: nn.Module,
weights: Iterable[tuple[str, torch.Tensor]],
) -> Iterable[str]:
if isinstance(module, PPMissingLayer):
if isinstance(module, (LMMissingLayer, TowerMissingLayer, PPMissingLayer)):
return
# Avoid infinite recursion since this function is typically
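With this check in place, weights that map onto an LMMissingLayer or TowerMissingLayer placeholder are dropped during loading, which is why the per-model skip_prefixes bookkeeping could be deleted in the earlier sections. A rough sketch of the idea (simplified; the real traversal in AutoWeightsLoader is recursive and prefix-aware, but the skip decision reduces to a check like this):

# Simplified illustration only; the helper below is hypothetical.
import torch.nn as nn

from vllm.model_executor.models.interfaces import LMMissingLayer, TowerMissingLayer


def skipped_prefixes(model: nn.Module) -> list[str]:
    """Prefixes whose weights would be dropped because the module is a placeholder."""
    return [
        f"{name}."
        for name, child in model.named_children()
        if isinstance(child, (LMMissingLayer, TowerMissingLayer))
    ]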