From 2b8a38b6d6be4f6e09cc20381c7027e7c35cb6c9 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Thu, 22 Jan 2026 16:20:27 +0800 Subject: [PATCH] [Model] Extend `collect_children` and `no_init_weights` contexts (#32757) Signed-off-by: DarkLight1337 --- tests/utils_/test_collection_utils.py | 20 +- vllm/model_executor/model_loader/utils.py | 17 +- vllm/model_executor/models/adapters.py | 174 +++++++++--------- vllm/model_executor/models/bagel.py | 8 +- vllm/model_executor/models/chameleon.py | 20 +- vllm/model_executor/models/gemma3_mm.py | 5 +- vllm/model_executor/models/glm4v.py | 20 +- vllm/model_executor/models/idefics3.py | 22 ++- vllm/model_executor/models/interfaces.py | 171 ++++++++++------- vllm/model_executor/models/internlm2.py | 16 +- vllm/model_executor/models/minicpmv.py | 6 +- vllm/model_executor/models/mllama4.py | 9 +- .../model_executor/models/nano_nemotron_vl.py | 4 +- vllm/model_executor/models/paddleocr_vl.py | 4 +- vllm/model_executor/models/paligemma.py | 5 +- vllm/model_executor/models/qwen3_vl_moe.py | 15 +- vllm/model_executor/models/qwen_vl.py | 20 +- vllm/model_executor/models/utils.py | 123 +++++++++++-- vllm/model_executor/models/whisper.py | 10 +- vllm/utils/collection_utils.py | 32 +++- 20 files changed, 444 insertions(+), 257 deletions(-) diff --git a/tests/utils_/test_collection_utils.py b/tests/utils_/test_collection_utils.py index 19f4a3d1c..578941260 100644 --- a/tests/utils_/test_collection_utils.py +++ b/tests/utils_/test_collection_utils.py @@ -2,11 +2,27 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import pytest -from vllm.utils.collection_utils import swap_dict_values +from vllm.utils.collection_utils import common_prefix, swap_dict_values @pytest.mark.parametrize( - "obj,key1,key2", + ("inputs", "expected_output"), + [ + ([""], ""), + (["a"], "a"), + (["a", "b"], ""), + (["a", "ab"], "a"), + (["a", "ab", "b"], ""), + (["abc", "a", "ab"], "a"), + (["aba", "abc", "ab"], "ab"), + ], +) +def test_common_prefix(inputs, expected_output): + assert common_prefix(inputs) == expected_output + + +@pytest.mark.parametrize( + ("obj", "key1", "key2"), [ # Tests for both keys exist ({1: "a", 2: "b"}, 1, 2), diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py index 74b02e4c6..1d67cb835 100644 --- a/vllm/model_executor/model_loader/utils.py +++ b/vllm/model_executor/model_loader/utils.py @@ -18,7 +18,7 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase, ) -from vllm.model_executor.models.interfaces import SupportsQuant, supports_multimodal +from vllm.model_executor.models.interfaces import SupportsQuant from vllm.utils.platform_utils import is_pin_memory_available logger = init_logger(__name__) @@ -165,11 +165,7 @@ _MODEL_ARCH_BY_HASH = dict[int, tuple[type[nn.Module], str]]() def _get_model_architecture(model_config: ModelConfig) -> tuple[type[nn.Module], str]: - from vllm.model_executor.models.adapters import ( - as_embedding_model, - as_seq_cls_model, - try_create_mm_pooling_model_cls, - ) + from vllm.model_executor.models.adapters import as_embedding_model, as_seq_cls_model architectures = getattr(model_config.hf_config, "architectures", []) @@ -189,15 +185,6 @@ def _get_model_architecture(model_config: ModelConfig) -> tuple[type[nn.Module], ) convert_type = model_config.convert_type - if convert_type != "none" and supports_multimodal(model_cls): - logger.debug_once("Detected conversion of Multi Modal model.") - converted = try_create_mm_pooling_model_cls(model_cls) - if converted is not None: - logger.debug_once("Creating wrapper class to forward pooler.") - return converted, arch - else: - logger.debug_once("Attempting direct conversion.") - if convert_type == "none": pass elif convert_type == "embed": diff --git a/vllm/model_executor/models/adapters.py b/vllm/model_executor/models/adapters.py index 719804b78..c36090e8f 100644 --- a/vllm/model_executor/models/adapters.py +++ b/vllm/model_executor/models/adapters.py @@ -1,8 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import ast -import inspect from collections.abc import Iterable from typing import TYPE_CHECKING, Any, TypeVar, cast @@ -18,10 +16,12 @@ from vllm.transformers_utils.config import ( ) from vllm.transformers_utils.repo_utils import get_hf_file_bytes +from .interfaces import supports_multimodal from .interfaces_base import VllmModelForPooling, is_pooling_model if TYPE_CHECKING: from vllm.config import ModelConfig, VllmConfig + from vllm.model_executor.layers.pooler import Pooler _T = TypeVar("_T", bound=type[nn.Module]) @@ -124,41 +124,12 @@ def _get_pooling_model_name(orig_model_name: str, pooling_suffix: str) -> str: return model_name + pooling_suffix -def try_create_mm_pooling_model_cls(orig_cls: _T) -> _T: - class CallVisitor(ast.NodeVisitor): - def __init__(self): - self.calls = [] - - def visit_Call(self, node): - if isinstance(node.func, ast.Name): - self.calls.append(node.func.id) - self.generic_visit(node) - - visitor = CallVisitor() - visitor.visit(ast.parse(inspect.getsource(orig_cls))) - if "init_vllm_registered_model" not in visitor.calls: - return None - - class ModelForPooling(orig_cls, VllmModelForPooling): - is_pooling_model = True - - def __init__( - self, - *, - vllm_config: "VllmConfig", - prefix: str = "", - **kwargs: Any, - ) -> None: - super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs) - - self.pooler = self.get_language_model().pooler - - return ModelForPooling # type: ignore - - def _create_pooling_model_cls(orig_cls: _T) -> _T: # Lazy import - from .utils import AutoWeightsLoader, WeightsMapper + from vllm.model_executor.layers.logits_processor import LogitsProcessor + from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead + + from .utils import AutoWeightsLoader, StageMissingLayer, no_init_weights class ModelForPooling(orig_cls, VllmModelForPooling): is_pooling_model = True @@ -170,69 +141,84 @@ def _create_pooling_model_cls(orig_cls: _T) -> _T: prefix: str = "", **kwargs: Any, ) -> None: - super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs) + with no_init_weights( + self, + lambda mod: StageMissingLayer("output", mod), + targets=(LogitsProcessor, ParallelLMHead), + ): + super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs) + # Used by SEQ_CLS_LOAD_METHODS self.vllm_config = vllm_config - # These are not used in pooling models - objects_to_clean = [self] - if language_model := getattr(self, "language_model", None): - objects_to_clean.append(language_model) - - for obj in objects_to_clean: - for attr in ("lm_head", "logits_processor"): - if hasattr(obj, attr): - delattr(obj, attr) - # If the model already defines a pooler instance, don't overwrite it - if not getattr(self, "pooler", None): - self._init_pooler(vllm_config, prefix=prefix) + pooler = getattr(self, "pooler", None) + if not pooler and supports_multimodal(self): + # Try to get the pooler from the LM backbone + language_model = self.get_language_model() + if hasattr(language_model, "pooler"): + pooler = language_model.pooler - def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""): + if not pooler: + pooler = self._init_pooler(vllm_config, prefix=prefix) + + self.pooler = pooler + + def _init_pooler( + self, + vllm_config: "VllmConfig", + prefix: str = "", + ) -> "Pooler": raise NotImplementedError - def load_weights( - self, - weights: Iterable[tuple[str, torch.Tensor]], - load_lm_head: bool = False, - ): - # TODO: Support uninitialized params tracking + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): + params_dict = dict(self.named_parameters()) - # For most pooling models: We have deleted this attribute, so don't load it. - # For converting an LLM into a seq cls model, we need the lm_head. - if not load_lm_head: - weights = ( - (name, data) - for name, data in weights - if not name.startswith("lm_head.") + # We support loading from both `*ForCausalLM` and `*Model` + candidate_prefixes = ["", "model."] + target_prefix = "" + + seen_weights = list[tuple[str, torch.Tensor]]() + for name, loaded_weight in weights: + seen_weights.append((name, loaded_weight)) + + try: + target_prefix = next( + prefix + for prefix in candidate_prefixes + if prefix + name in params_dict + ) + break + except StopIteration: + # The weight might not exist on the model + # (to be handled by AutoWeightsLoader) + pass + + if target_prefix: + target_model = self + for attr in target_prefix.split("."): + if attr: + target_model = getattr(self, attr) + + logger.info( + "Mapping weights to %s as they are " + "relative to this model instead of %s.", + target_model._get_name(), + self._get_name(), ) - # If `*ForCausalLM` defines `load_weights` on the inner model - # and there are no other inner modules with parameters, - # we support loading from both `*Model` and `*ForCausalLM` - if hasattr(self, "model") and hasattr(self.model, "load_weights"): - # Whether only `self.model` contains parameters - model_is_only_param = all( - name == "model" or next(child.parameters(), None) is None - for name, child in self.named_children() - ) + mapped_weights = ( + (target_prefix + name, weight) + for name, weight in (*seen_weights, *weights) + ) - if model_is_only_param: - mapper = WeightsMapper(orig_to_new_prefix={"model.": ""}) - weights = mapper.apply(weights) - - loaded_params = self.model.load_weights(weights) - loaded_params = {f"model.{name}" for name in loaded_params} - return loaded_params - - # For most other models - if hasattr(orig_cls, "load_weights"): - return orig_cls.load_weights(self, weights) # type: ignore - # Fallback - else: + def default_load_weights(weights): loader = AutoWeightsLoader(self) return loader.load_weights(weights) + load_weights = getattr(super(), "load_weights", default_load_weights) + return load_weights(mapped_weights) + return ModelForPooling # type: ignore @@ -255,11 +241,15 @@ def as_embedding_model(cls: _T) -> _T: from vllm.model_executor.layers.pooler import DispatchPooler class ModelForEmbedding(_create_pooling_model_cls(cls)): - def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""): + def _init_pooler( + self, + vllm_config: "VllmConfig", + prefix: str = "", + ) -> "Pooler": pooler_config = vllm_config.model_config.pooler_config assert pooler_config is not None - self.pooler = DispatchPooler.for_embedding(pooler_config) + return DispatchPooler.for_embedding(pooler_config) ModelForEmbedding.__name__ = _get_pooling_model_name(cls.__name__, "ForEmbedding") @@ -292,7 +282,11 @@ def as_seq_cls_model(cls: _T) -> _T: class ModelForSequenceClassification( _create_pooling_model_cls(cls), SupportsCrossEncoding ): - def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""): + def _init_pooler( + self, + vllm_config: "VllmConfig", + prefix: str = "", + ) -> "Pooler": text_config = vllm_config.model_config.hf_config.get_text_config() model_config = vllm_config.model_config quant_config = vllm_config.quant_config @@ -310,9 +304,7 @@ def as_seq_cls_model(cls: _T) -> _T: pooler_config = vllm_config.model_config.pooler_config assert pooler_config is not None - self.pooler = DispatchPooler.for_seq_cls( - pooler_config, classifier=self.score - ) + return DispatchPooler.for_seq_cls(pooler_config, classifier=self.score) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): hf_config = self.config @@ -424,7 +416,7 @@ def load_weights_using_from_2_way_softmax( pooling_model_cls = next( x for x in type(model).__mro__ if x.__name__ == "ModelForPooling" ) - loaded_weights = pooling_model_cls.load_weights(model, weights, load_lm_head=True) + loaded_weights = pooling_model_cls.load_weights(model, weights) from vllm.tokenizers import get_tokenizer diff --git a/vllm/model_executor/models/bagel.py b/vllm/model_executor/models/bagel.py index 335a99509..0d28a9a53 100644 --- a/vllm/model_executor/models/bagel.py +++ b/vllm/model_executor/models/bagel.py @@ -44,11 +44,11 @@ from .interfaces import ( SupportsLoRA, SupportsMultiModal, SupportsPP, - TowerMissingLayer, ) from .siglip import SiglipVisionModel from .utils import ( AutoWeightsLoader, + StageMissingLayer, WeightsMapper, init_vllm_registered_model, maybe_prefix, @@ -426,9 +426,9 @@ class BagelForConditionalGeneration( hidden_size=llm_hidden_size, ) else: - self.vit_model = TowerMissingLayer("image") - self.connector = TowerMissingLayer("image") - self.vit_pos_embed = TowerMissingLayer("image") + self.vit_model = StageMissingLayer("image_tower") + self.connector = StageMissingLayer("image_tower") + self.vit_pos_embed = StageMissingLayer("image_tower") self.make_empty_intermediate_tensors = ( self.language_model.make_empty_intermediate_tensors diff --git a/vllm/model_executor/models/chameleon.py b/vllm/model_executor/models/chameleon.py index 5b8ecc0b6..339ecaeb7 100644 --- a/vllm/model_executor/models/chameleon.py +++ b/vllm/model_executor/models/chameleon.py @@ -935,9 +935,20 @@ class ChameleonForConditionalGeneration( multimodal_config = vllm_config.model_config.multimodal_config self.config = config self.multimodal_config = multimodal_config - self.model = ChameleonModel( - vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") - ) + + with self._mark_composite_model( + vllm_config, + language_targets=( + ChameleonDecoderLayer + if not self.config.swin_norm + else ChameleonSwinDecoderLayer + ), + tower_targets={"image": ChameleonVQVAE}, + ): + self.model = ChameleonModel( + vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model"), + ) self.lm_head = ParallelLMHead( config.vocab_size, @@ -970,9 +981,6 @@ class ChameleonForConditionalGeneration( resolve_bindings={"h": expected_h, "w": expected_w}, ) - def get_language_model(self) -> torch.nn.Module: - return self.model - def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings: image_input = self._parse_and_validate_image_input(**kwargs) if image_input is None: diff --git a/vllm/model_executor/models/gemma3_mm.py b/vllm/model_executor/models/gemma3_mm.py index 52ae40ba3..95e372291 100644 --- a/vllm/model_executor/models/gemma3_mm.py +++ b/vllm/model_executor/models/gemma3_mm.py @@ -539,10 +539,7 @@ class Gemma3ForConditionalGeneration( ) logit_scale = getattr(config, "logit_scale", 1.0) - if hasattr(self.language_model, "logits_processor"): - # The logits processor can be unset if we're using - # automatic conversion to pooling model. - self.language_model.logits_processor.scale *= logit_scale + self.language_model.logits_processor.scale *= logit_scale self.make_empty_intermediate_tensors = ( self.language_model.make_empty_intermediate_tensors diff --git a/vllm/model_executor/models/glm4v.py b/vllm/model_executor/models/glm4v.py index 615731bda..fca7c49cc 100644 --- a/vllm/model_executor/models/glm4v.py +++ b/vllm/model_executor/models/glm4v.py @@ -53,7 +53,7 @@ from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs import ChatGLMConfig from vllm.utils.tensor_schema import TensorSchema, TensorShape -from .chatglm import ChatGLMBaseModel, ChatGLMModel +from .chatglm import ChatGLMBaseModel, ChatGLMModel, GLMTransformer from .interfaces import ( MultiModalEmbeddings, SupportsLoRA, @@ -591,11 +591,16 @@ class GLM4VForCausalLM( prefix: str = "", transformer_type: type[GLM4VModel] = GLM4VModel, ) -> None: - super().__init__( - vllm_config=vllm_config, - prefix=prefix, - transformer_type=transformer_type, - ) + with self._mark_composite_model( + vllm_config, + language_targets=GLMTransformer, + tower_targets={"image": EVA2CLIPModel}, + ): + super().__init__( + vllm_config=vllm_config, + prefix=prefix, + transformer_type=transformer_type, + ) self.transformer: GLM4VModel @@ -752,9 +757,6 @@ class GLM4VForCausalLM( mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item() return llm_positions, mrope_position_delta - def get_language_model(self) -> torch.nn.Module: - return self.transformer - embed_input_ids = SupportsMultiModal.embed_input_ids def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings: diff --git a/vllm/model_executor/models/idefics3.py b/vllm/model_executor/models/idefics3.py index 9e8be5ca0..fc88f07be 100644 --- a/vllm/model_executor/models/idefics3.py +++ b/vllm/model_executor/models/idefics3.py @@ -57,7 +57,11 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape from .idefics2_vision_model import ( Idefics2VisionTransformer as Idefics3VisionTransformer, ) -from .interfaces import MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal +from .interfaces import ( + MultiModalEmbeddings, + SupportsLoRA, + SupportsMultiModal, +) from .llama import LlamaModel from .utils import AutoWeightsLoader, maybe_prefix @@ -604,9 +608,16 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLo self.config = config self.multimodal_config = multimodal_config - self.model = Idefics3Model( - vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") - ) + with self._mark_composite_model( + vllm_config, + language_targets=LlamaModel, + tower_targets={"image": (Idefics3VisionTransformer, Idefics3Connector)}, + ): + self.model = Idefics3Model( + vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model"), + ) + self.image_token_id = self.config.image_token_id self.lm_head = ParallelLMHead( @@ -669,9 +680,6 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLo num_patches = image_input["num_patches"] return [e.flatten(0, 1) for e in image_features.split(num_patches.tolist())] - def get_language_model(self) -> torch.nn.Module: - return self.model - def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings: image_input = self._parse_and_validate_image_input(**kwargs) if image_input is None: diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index 809395cf3..013ece5fe 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Callable, Iterable, Mapping, MutableSequence -from contextlib import contextmanager, nullcontext +from contextlib import ExitStack, contextmanager, nullcontext from typing import ( TYPE_CHECKING, ClassVar, @@ -25,6 +25,7 @@ from vllm.inputs import TokensPrompt from vllm.inputs.data import PromptType from vllm.logger import init_logger from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.utils.collection_utils import common_prefix from vllm.utils.func_utils import supports_kw from .interfaces_base import VllmModel, is_pooling_model @@ -70,46 +71,8 @@ def _require_is_multimodal(is_multimodal: Tensor | None) -> Tensor: return is_multimodal -class LMMissingLayer(nn.Module): - def make_empty_intermediate_tensors(self, *args, **kwargs): - raise RuntimeError("This module should not be called in MM encoder-only mode") - - def __call__(self, *args, **kwargs): - raise RuntimeError("This module should not be called in MM encoder-only mode") - - -class TowerMissingLayer(nn.Module): - def __init__(self, modalities: set[str] | str) -> None: - if isinstance(modalities, str): - modalities = {modalities} - - super().__init__() - - self.modalities = modalities - - def __call__(self, *args, **kwargs): - raise RuntimeError( - f"This module should not be called when the following " - f"modalities are disabled: {self.modalities}" - ) - - -@contextmanager -def _no_init_weights(module: nn.Module, placeholder: Callable[[], nn.Module]): - """ - Within this context, prevent weight initialization from using device memory and - replace direct child assignments to `module` with the result of `placeholder()`. - """ - - def callback(module_, name, submodule): - if module_ is module: - return placeholder() - - return submodule - - with torch.nn.modules.module.register_module_module_registration_hook(callback): # noqa: E501,SIM117 - with torch.device("meta"): - yield +# Cache results of `SupportsMultiModal.get_language_model` +_language_model_by_module = dict[nn.Module, VllmModel]() @runtime_checkable @@ -187,31 +150,61 @@ class SupportsMultiModal(Protocol): Returns: torch.nn.Module: The core language model component. """ + # Cached + if self in _language_model_by_module: + return _language_model_by_module[self] + if self._language_model_names: - return getattr(self, self._language_model_names[0]) + mod = self + for attr in common_prefix( + [name.split(".") for name in self._language_model_names] + ): + if attr: + mod = getattr(mod, attr) + + if mod is not self and hasattr(mod, "embed_input_ids"): + _language_model_by_module[self] = mod + return mod + + # Fallback + for mod in self.children(): + if hasattr(mod, "embed_input_ids"): + _language_model_by_module[self] = mod + return mod raise NotImplementedError( f"No language model found in {type(self).__name__}! " - "You should initialize it inside `_mark_language_model`." + "You should initialize it via `_mark_language_model`." ) @contextmanager - def _mark_language_model(self, vllm_config: VllmConfig): + def _mark_language_model( + self, + vllm_config: VllmConfig, + *, + targets: type[nn.Module] | tuple[type[nn.Module], ...] | None = None, + ): """ - Mark each child module that was assigned to this model - during this context as a language model component. + Mark each child module that was assigned to this model during this context + as a language model component. + + Language model components are automatically skipped in `--mm-encoder-only` + mode. + + If `targets` is set, instead include descendants that are an instance + of `targets`, even if they aren't direct children. """ + from .utils import StageMissingLayer, collect_children, no_init_weights + mm_config = vllm_config.model_config.multimodal_config - children_names = list[str]() - - def callback(module_, name, submodule): - if module_ is self: - children_names.append(name) - - with torch.nn.modules.module.register_module_module_registration_hook(callback): # noqa: E501,SIM117 + with collect_children(self, targets=targets) as children_names: # noqa: SIM117 with ( - _no_init_weights(self, LMMissingLayer) + no_init_weights( + self, + lambda mod: StageMissingLayer("language_model", mod), + targets=targets, + ) if mm_config.mm_encoder_only else nullcontext() ): @@ -220,25 +213,42 @@ class SupportsMultiModal(Protocol): self._language_model_names = children_names @contextmanager - def _mark_tower_model(self, vllm_config: VllmConfig, modalities: set[str] | str): + def _mark_tower_model( + self, + vllm_config: VllmConfig, + modalities: set[str] | str, + *, + targets: type[nn.Module] | tuple[type[nn.Module], ...] | None = None, + ): """ - Mark each child module that was assigned to this model - during this context as a tower model component. + Mark each child module that was assigned to this model during this context + as a tower model component. + + Tower model components are automatically skipped when `--limit-mm-per-prompt` + is set to zero for all of their modalities. + + If `targets` is set, instead include descendants that are an instance + of `targets`, even if they aren't direct children. """ + from .utils import StageMissingLayer, collect_children, no_init_weights + if isinstance(modalities, str): modalities = {modalities} + if modalities == {"image", "video"}: + stage_name = "vision_tower" + else: + stage_name = "_".join([*modalities, "tower"]) + mm_config = vllm_config.model_config.multimodal_config - children_names = list[str]() - - def callback(module_, name, submodule): - if module_ is self: - children_names.append(name) - - with torch.nn.modules.module.register_module_module_registration_hook(callback): # noqa: E501,SIM117 + with collect_children(self, targets=targets) as children_names: # noqa: SIM117 with ( - _no_init_weights(self, lambda: TowerMissingLayer(modalities)) + no_init_weights( + self, + lambda mod: StageMissingLayer(stage_name, mod), + targets=targets, + ) if all(mm_config.get_limit_per_prompt(m) == 0 for m in modalities) else nullcontext() ): @@ -246,6 +256,37 @@ class SupportsMultiModal(Protocol): self._tower_model_names = children_names + @contextmanager + def _mark_composite_model( + self, + vllm_config: VllmConfig, + *, + language_targets: type[nn.Module] | tuple[type[nn.Module], ...], + tower_targets: dict[str, type[nn.Module] | tuple[type[nn.Module], ...]], + ): + """ + Composite wrapper over `_mark_language_model` and + `_mark_tower_model` by modality. + """ + with ExitStack() as stack: + stack.enter_context( + self._mark_language_model( + vllm_config, + targets=language_targets, + ) + ) + + for modality, modality_targets in tower_targets.items(): + stack.enter_context( + self._mark_tower_model( + vllm_config, + modality, + targets=modality_targets, + ) + ) + + yield + def get_num_mm_encoder_tokens(self, num_image_tokens: int) -> int: """ Implement this function to enable LoRA support diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py index 45628b4fe..5dec47e09 100644 --- a/vllm/model_executor/models/internlm2.py +++ b/vllm/model_executor/models/internlm2.py @@ -41,10 +41,12 @@ from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP from .interfaces_base import default_pooling_type from .utils import ( + StageMissingLayer, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix, + no_init_weights, ) @@ -413,10 +415,16 @@ class InternLM2ForRewardModel(InternLM2ForCausalLM): prefix: str = "", model_type: type[InternLM2Model] = InternLM2Model, ): - super().__init__(vllm_config=vllm_config, prefix=prefix, model_type=model_type) - - for attr in ("output", "logits_processor"): - delattr(self, attr) + with no_init_weights( + self, + lambda mod: StageMissingLayer("output", mod), + targets=(LogitsProcessor, ParallelLMHead), + ): + super().__init__( + vllm_config=vllm_config, + prefix=prefix, + model_type=model_type, + ) config = vllm_config.model_config.hf_config self.head_dtype = vllm_config.model_config.head_dtype diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index de76e9abf..d9179250b 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -1035,11 +1035,13 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP): ) with self._mark_tower_model(vllm_config, {"image", "video"}): - self.vpm = vpm = self.init_vision_module( + self.vpm = self.init_vision_module( config, quant_config, prefix=maybe_prefix(prefix, "vpm") ) self.vision_dim = ( - vpm.embed_dim if self.version == (2, 0) else vpm.embeddings.embed_dim + self.vpm.embed_dim + if self.version == (2, 0) + else self.vpm.embeddings.embed_dim ) self.embed_dim = self.config.hidden_size diff --git a/vllm/model_executor/models/mllama4.py b/vllm/model_executor/models/mllama4.py index 53e1eb036..7db496758 100644 --- a/vllm/model_executor/models/mllama4.py +++ b/vllm/model_executor/models/mllama4.py @@ -70,20 +70,15 @@ from vllm.sequence import IntermediateTensors from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import ( - LMMissingLayer, MixtureOfExperts, MultiModalEmbeddings, SupportsEagle3, SupportsLoRA, SupportsMultiModal, SupportsPP, - TowerMissingLayer, ) from .llama4 import Llama4ForCausalLM -from .utils import ( - AutoWeightsLoader, - maybe_prefix, -) +from .utils import AutoWeightsLoader, StageMissingLayer, maybe_prefix from .vision import run_dp_sharded_vision_model @@ -1024,7 +1019,7 @@ class Llama4ForConditionalGeneration( renamed = self._rename_weight_for_modelopt_checkpoint(name) attr = renamed.split(".", 1)[0] - if isinstance(getattr(self, attr), (LMMissingLayer, TowerMissingLayer)): + if isinstance(getattr(self, attr), StageMissingLayer): continue if renamed.startswith("language_model."): diff --git a/vllm/model_executor/models/nano_nemotron_vl.py b/vllm/model_executor/models/nano_nemotron_vl.py index 82f3f1362..aa4745647 100644 --- a/vllm/model_executor/models/nano_nemotron_vl.py +++ b/vllm/model_executor/models/nano_nemotron_vl.py @@ -1513,7 +1513,7 @@ class NemotronH_Nano_VL_V2( self.video_pruning_rate = multimodal_config.video_pruning_rate with self._mark_language_model(vllm_config): - self.language_model = language_model = init_vllm_registered_model( + self.language_model = init_vllm_registered_model( vllm_config=vllm_config, hf_config=config.text_config, prefix=maybe_prefix(prefix, "language_model"), @@ -1542,7 +1542,7 @@ class NemotronH_Nano_VL_V2( ReLUSquaredActivation(), nn.Linear(vision_projection_hidden_size, llm_hidden_size, bias=False), ) - self.mlp1 = mlp1.to(language_model.config.dtype) + self.mlp1 = mlp1.to(self.language_model.config.dtype) self.config = config self.model_config = vllm_config.model_config diff --git a/vllm/model_executor/models/paddleocr_vl.py b/vllm/model_executor/models/paddleocr_vl.py index ebccef986..24749c7cf 100644 --- a/vllm/model_executor/models/paddleocr_vl.py +++ b/vllm/model_executor/models/paddleocr_vl.py @@ -1025,12 +1025,12 @@ class PaddleOCRVLForConditionalGeneration(nn.Module, SupportsMultiModal, Support self.mlp_AR = Projector(config, config.vision_config) with self._mark_language_model(vllm_config): - self.language_model = language_model = Ernie4_5ForCausalLM( + self.language_model = Ernie4_5ForCausalLM( vllm_config=vllm_config, prefix=maybe_prefix(prefix, "language_model"), ) - for layer in language_model.model.layers: + for layer in self.language_model.model.layers: if not isinstance(layer, PPMissingLayer): layer.self_attn.rotary_emb.is_neox_style = True diff --git a/vllm/model_executor/models/paligemma.py b/vllm/model_executor/models/paligemma.py index 6a9cc5c03..533f060ae 100644 --- a/vllm/model_executor/models/paligemma.py +++ b/vllm/model_executor/models/paligemma.py @@ -314,13 +314,14 @@ class PaliGemmaForConditionalGeneration( config.text_config.architectures = ["Gemma2ForCausalLM"] with self._mark_language_model(vllm_config): - self.language_model = language_model = init_vllm_registered_model( + self.language_model = init_vllm_registered_model( vllm_config=vllm_config, hf_config=config.text_config, prefix=maybe_prefix(prefix, "language_model"), ) + logit_scale = getattr(config, "logit_scale", 1.0) - language_model.logits_processor.scale *= logit_scale + self.language_model.logits_processor.scale *= logit_scale self.make_empty_intermediate_tensors = ( self.language_model.make_empty_intermediate_tensors diff --git a/vllm/model_executor/models/qwen3_vl_moe.py b/vllm/model_executor/models/qwen3_vl_moe.py index f65e95279..50b511dd2 100644 --- a/vllm/model_executor/models/qwen3_vl_moe.py +++ b/vllm/model_executor/models/qwen3_vl_moe.py @@ -461,15 +461,16 @@ class Qwen3VLMoeForConditionalGeneration( ] with self._mark_language_model(vllm_config): - self.language_model = language_model = Qwen3MoeLLMForCausalLM( - vllm_config=vllm_config, prefix=maybe_prefix(prefix, "language_model") + self.language_model = Qwen3MoeLLMForCausalLM( + vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "language_model"), ) - # Whether to include the gate_up_proj mapping is determined by - # the language model. - self.packed_modules_mapping = ( - self.packed_modules_mapping | language_model.packed_modules_mapping - ) + # Whether to include the gate_up_proj mapping is determined by + # the language model. + self.packed_modules_mapping = ( + self.packed_modules_mapping | self.language_model.packed_modules_mapping + ) self.make_empty_intermediate_tensors = ( self.language_model.make_empty_intermediate_tensors diff --git a/vllm/model_executor/models/qwen_vl.py b/vllm/model_executor/models/qwen_vl.py index 83a2b2b6c..11856dd2d 100644 --- a/vllm/model_executor/models/qwen_vl.py +++ b/vllm/model_executor/models/qwen_vl.py @@ -58,7 +58,7 @@ from .interfaces import ( SupportsMultiModal, SupportsPP, ) -from .qwen import QWenBaseModel, QWenModel +from .qwen import QWenBaseModel, QWenBlock, QWenModel class QwenImagePixelInputs(TensorSchema): @@ -757,11 +757,16 @@ class QwenVLForConditionalGeneration( prefix: str = "", transformer_type: type[QwenVLModel] = QwenVLModel, ) -> None: - super().__init__( - vllm_config=vllm_config, - prefix=prefix, - transformer_type=transformer_type, - ) + with self._mark_composite_model( + vllm_config, + language_targets=QWenBlock, + tower_targets={"image": VisionTransformer}, + ): + super().__init__( + vllm_config=vllm_config, + prefix=prefix, + transformer_type=transformer_type, + ) self.transformer: QwenVLModel @@ -795,9 +800,6 @@ class QwenVLForConditionalGeneration( return self.transformer.visual(image_input["data"]) - def get_language_model(self) -> torch.nn.Module: - return self.transformer - def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings: image_input = self._parse_and_validate_image_input(**kwargs) if image_input is None: diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index b22c235cb..d3e1434b7 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -2,13 +2,15 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import itertools -from collections.abc import Iterable, Mapping +from collections.abc import Callable, Iterable, Mapping +from contextlib import contextmanager from dataclasses import dataclass, field from typing import Any, Literal, Protocol, overload import torch import torch.nn as nn from torch.func import functional_call +from torch.nn.modules.module import register_module_module_registration_hook from transformers import PretrainedConfig from vllm.config import VllmConfig @@ -24,11 +26,7 @@ from vllm.model_executor.model_loader.online_quantization import ( support_quantized_model_reload_from_hp_weights, ) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.interfaces import ( - LMMissingLayer, - TowerMissingLayer, - supports_any_eagle, -) +from vllm.model_executor.models.interfaces import supports_any_eagle from vllm.multimodal import NestedTensors from vllm.sequence import IntermediateTensors from vllm.utils.math_utils import cdiv @@ -214,8 +212,8 @@ class AutoWeightsLoader: continue raise ValueError( - f"Attempted to load nested weight '{weight_qualname}' " - f"into a single parameter '{base_prefix}'" + f"Attempted to load nested weight {weight_qualname!r} " + f"into a single parameter {base_prefix!r}" ) weight_loader = getattr(param, "weight_loader", default_weight_loader) @@ -254,7 +252,7 @@ class AutoWeightsLoader: module: nn.Module, weights: Iterable[tuple[str, torch.Tensor]], ) -> Iterable[str]: - if isinstance(module, (LMMissingLayer, TowerMissingLayer, PPMissingLayer)): + if isinstance(module, (StageMissingLayer, PPMissingLayer)): return # Avoid infinite recursion since this function is typically @@ -316,9 +314,14 @@ class AutoWeightsLoader: continue + desc_param_keys = { + base_prefix + k for k, _ in module.named_parameters(recurse=True) + } msg = ( - f"There is no module or parameter named '{prefix}' " - f"in {type(self.module).__name__}" + f"There is no module or parameter named {prefix!r} " + f"in {self.module._get_name()}. " + f"The available parameters belonging to {base_prefix} " + f"({module._get_name()}) are: {desc_param_keys}" ) raise ValueError(msg) @@ -496,6 +499,100 @@ def isin_list( return torch.isin(elements, test_elements) +class StageMissingLayer(nn.Module): + def __init__(self, stage_name: str, module: nn.Module | None = None) -> None: + super().__init__() + + self.stage_name = stage_name + + # Don't register this as a child module in order to + # avoid missing keys when loading weights + self.__dict__["module"] = module + + def __getattr__(self, name: str): + return getattr(self.__dict__["module"], name) + + def __call__(self, *args, **kwargs): + raise RuntimeError(f"{self} should not be called") + + def extra_repr(self) -> str: + return f"stage_name={self.stage_name!r}" + + +@contextmanager +def collect_children( + module: nn.Module, + *, + targets: type[nn.Module] | tuple[type[nn.Module], ...] | None = None, +): + """ + Within this context, collect all direct child assignments to `module`, + returning a list of children names that is internally updated until the + context is exited. + + If `targets` is set, instead collect descendents of `module` + that are an instance of `targets`, even if they aren't direct children. + """ + children_names = list[str]() + + if targets is None: + + def hook(module_: nn.Module, name: str, submodule: nn.Module): + if module_ is module: + children_names.append(name) + + with register_module_module_registration_hook(hook): + yield children_names + else: + yield children_names + + for name, module_ in module.named_modules(): + if isinstance(module_, targets): + children_names.append(name) + + +@contextmanager +def no_init_weights( + module: nn.Module, + placeholder: Callable[[nn.Module], nn.Module], + *, + targets: type[nn.Module] | tuple[type[nn.Module], ...] | None = None, +): + """ + Within this context, prevent weight initialization from using device memory and + replace direct child assignments to `module` with the result of `placeholder()`. + + If `targets` is set, instead prevent weight initialization and + replace assignments where the child is an instance of `targets`, + even if they aren't direct children of `module`. + """ + if targets is None: + + def hook(module_: nn.Module, name: str, submodule: nn.Module): + if module_ is module: + return placeholder(submodule) + + return submodule + + with register_module_module_registration_hook(hook), torch.device("meta"): + yield + else: + + def hook(module_: nn.Module, name: str, submodule: nn.Module): + if isinstance(module_, targets): + submodule.to("meta") # Free memory + if isinstance(submodule, targets): + submodule.to("meta") # Free memory + return placeholder(submodule) + + return submodule + + # Not all descendents are targeted, so we can't use a blanket + # `torch.device("meta")` context + with register_module_module_registration_hook(hook): + yield + + class LayerFn(Protocol): def __call__(self, prefix: str) -> torch.nn.Module: ... @@ -627,7 +724,7 @@ def get_pp_missing_layer_names(model: torch.nn.Module) -> list[str]: missing_layer_names = [] for name, module in model.named_modules(): - if isinstance(module, PPMissingLayer): + if isinstance(module, (StageMissingLayer, PPMissingLayer)): # NOTE: the trailing dot is used to match the prefix of the layer. # without the dot, we could match a layer that is not missing, # e.g., 'encoder.layer.1' would match 'encoder.layer.11' @@ -639,7 +736,7 @@ def get_pp_missing_layer_names(model: torch.nn.Module) -> list[str]: def is_pp_missing_parameter(name: str, model: torch.nn.Module) -> bool: """Check if a parameter is missing in a pipeline parallel model.""" - if isinstance(model, PPMissingLayer): + if isinstance(model, (StageMissingLayer, PPMissingLayer)): return True return any( diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py index 8d6726145..7e3d470a5 100644 --- a/vllm/model_executor/models/whisper.py +++ b/vllm/model_executor/models/whisper.py @@ -909,7 +909,12 @@ class WhisperForConditionalGeneration( self.config = config self.dtype = vllm_config.model_config.dtype - self.model = WhisperModel(vllm_config=vllm_config, prefix=prefix) + with self._mark_composite_model( + vllm_config, + language_targets=WhisperDecoder, + tower_targets={"audio": WhisperEncoder}, + ): + self.model = WhisperModel(vllm_config=vllm_config, prefix=prefix) self.proj_out = ParallelLMHead( config.vocab_size, @@ -937,9 +942,6 @@ class WhisperForConditionalGeneration( ) return decoder_outputs - def get_language_model(self) -> torch.nn.Module: - return self.model.decoder - def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings: # Required as part of SupportsMultiModal interface. audio_input = self._parse_and_validate_audio_input(**kwargs) diff --git a/vllm/utils/collection_utils.py b/vllm/utils/collection_utils.py index 3b19e1bd7..aefaf84ee 100644 --- a/vllm/utils/collection_utils.py +++ b/vllm/utils/collection_utils.py @@ -7,10 +7,10 @@ This is similar in concept to the `collections` module. """ from collections import defaultdict -from collections.abc import Callable, Generator, Hashable, Iterable, Mapping +from collections.abc import Callable, Generator, Hashable, Iterable, Mapping, Sequence from typing import Generic, Literal, TypeVar -from typing_extensions import TypeIs, assert_never +from typing_extensions import TypeIs, assert_never, overload T = TypeVar("T") @@ -74,6 +74,34 @@ def is_list_of( assert_never(check) +@overload +def common_prefix(items: Sequence[str]) -> str: ... + + +@overload +def common_prefix(items: Sequence[Sequence[T]]) -> Sequence[T]: ... + + +def common_prefix(items: Sequence[Sequence[T] | str]) -> Sequence[T] | str: + """Find the longest prefix common to all items.""" + if len(items) == 0: + return [] + if len(items) == 1: + return items[0] + + shortest = min(items, key=len) + if not shortest: + return shortest[:0] + + for match_len in range(1, len(shortest) + 1): + match = shortest[:match_len] + for item in items: + if item[:match_len] != match: + return shortest[: match_len - 1] + + return shortest + + def chunk_list(lst: list[T], chunk_size: int) -> Generator[list[T]]: """Yield successive chunk_size chunks from lst.""" for i in range(0, len(lst), chunk_size):