[Model] Extend collect_children and no_init_weights contexts (#32757)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2026-01-22 16:20:27 +08:00
committed by GitHub
parent 1bf1a34b19
commit 2b8a38b6d6
20 changed files with 444 additions and 257 deletions

View File

@@ -2,11 +2,27 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
from vllm.utils.collection_utils import swap_dict_values
from vllm.utils.collection_utils import common_prefix, swap_dict_values
@pytest.mark.parametrize(
"obj,key1,key2",
("inputs", "expected_output"),
[
([""], ""),
(["a"], "a"),
(["a", "b"], ""),
(["a", "ab"], "a"),
(["a", "ab", "b"], ""),
(["abc", "a", "ab"], "a"),
(["aba", "abc", "ab"], "ab"),
],
)
def test_common_prefix(inputs, expected_output):
assert common_prefix(inputs) == expected_output
@pytest.mark.parametrize(
("obj", "key1", "key2"),
[
# Tests for both keys exist
({1: "a", 2: "b"}, 1, 2),

View File

@@ -18,7 +18,7 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig,
QuantizeMethodBase,
)
from vllm.model_executor.models.interfaces import SupportsQuant, supports_multimodal
from vllm.model_executor.models.interfaces import SupportsQuant
from vllm.utils.platform_utils import is_pin_memory_available
logger = init_logger(__name__)
@@ -165,11 +165,7 @@ _MODEL_ARCH_BY_HASH = dict[int, tuple[type[nn.Module], str]]()
def _get_model_architecture(model_config: ModelConfig) -> tuple[type[nn.Module], str]:
from vllm.model_executor.models.adapters import (
as_embedding_model,
as_seq_cls_model,
try_create_mm_pooling_model_cls,
)
from vllm.model_executor.models.adapters import as_embedding_model, as_seq_cls_model
architectures = getattr(model_config.hf_config, "architectures", [])
@@ -189,15 +185,6 @@ def _get_model_architecture(model_config: ModelConfig) -> tuple[type[nn.Module],
)
convert_type = model_config.convert_type
if convert_type != "none" and supports_multimodal(model_cls):
logger.debug_once("Detected conversion of Multi Modal model.")
converted = try_create_mm_pooling_model_cls(model_cls)
if converted is not None:
logger.debug_once("Creating wrapper class to forward pooler.")
return converted, arch
else:
logger.debug_once("Attempting direct conversion.")
if convert_type == "none":
pass
elif convert_type == "embed":

View File

@@ -1,8 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import ast
import inspect
from collections.abc import Iterable
from typing import TYPE_CHECKING, Any, TypeVar, cast
@@ -18,10 +16,12 @@ from vllm.transformers_utils.config import (
)
from vllm.transformers_utils.repo_utils import get_hf_file_bytes
from .interfaces import supports_multimodal
from .interfaces_base import VllmModelForPooling, is_pooling_model
if TYPE_CHECKING:
from vllm.config import ModelConfig, VllmConfig
from vllm.model_executor.layers.pooler import Pooler
_T = TypeVar("_T", bound=type[nn.Module])
@@ -124,41 +124,12 @@ def _get_pooling_model_name(orig_model_name: str, pooling_suffix: str) -> str:
return model_name + pooling_suffix
def try_create_mm_pooling_model_cls(orig_cls: _T) -> _T:
class CallVisitor(ast.NodeVisitor):
def __init__(self):
self.calls = []
def visit_Call(self, node):
if isinstance(node.func, ast.Name):
self.calls.append(node.func.id)
self.generic_visit(node)
visitor = CallVisitor()
visitor.visit(ast.parse(inspect.getsource(orig_cls)))
if "init_vllm_registered_model" not in visitor.calls:
return None
class ModelForPooling(orig_cls, VllmModelForPooling):
is_pooling_model = True
def __init__(
self,
*,
vllm_config: "VllmConfig",
prefix: str = "",
**kwargs: Any,
) -> None:
super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)
self.pooler = self.get_language_model().pooler
return ModelForPooling # type: ignore
def _create_pooling_model_cls(orig_cls: _T) -> _T:
# Lazy import
from .utils import AutoWeightsLoader, WeightsMapper
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from .utils import AutoWeightsLoader, StageMissingLayer, no_init_weights
class ModelForPooling(orig_cls, VllmModelForPooling):
is_pooling_model = True
@@ -170,69 +141,84 @@ def _create_pooling_model_cls(orig_cls: _T) -> _T:
prefix: str = "",
**kwargs: Any,
) -> None:
super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)
with no_init_weights(
self,
lambda mod: StageMissingLayer("output", mod),
targets=(LogitsProcessor, ParallelLMHead),
):
super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)
# Used by SEQ_CLS_LOAD_METHODS
self.vllm_config = vllm_config
# These are not used in pooling models
objects_to_clean = [self]
if language_model := getattr(self, "language_model", None):
objects_to_clean.append(language_model)
for obj in objects_to_clean:
for attr in ("lm_head", "logits_processor"):
if hasattr(obj, attr):
delattr(obj, attr)
# If the model already defines a pooler instance, don't overwrite it
if not getattr(self, "pooler", None):
self._init_pooler(vllm_config, prefix=prefix)
pooler = getattr(self, "pooler", None)
if not pooler and supports_multimodal(self):
# Try to get the pooler from the LM backbone
language_model = self.get_language_model()
if hasattr(language_model, "pooler"):
pooler = language_model.pooler
def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""):
if not pooler:
pooler = self._init_pooler(vllm_config, prefix=prefix)
self.pooler = pooler
def _init_pooler(
self,
vllm_config: "VllmConfig",
prefix: str = "",
) -> "Pooler":
raise NotImplementedError
def load_weights(
self,
weights: Iterable[tuple[str, torch.Tensor]],
load_lm_head: bool = False,
):
# TODO: Support uninitialized params tracking
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
params_dict = dict(self.named_parameters())
# For most pooling models: We have deleted this attribute, so don't load it.
# For converting an LLM into a seq cls model, we need the lm_head.
if not load_lm_head:
weights = (
(name, data)
for name, data in weights
if not name.startswith("lm_head.")
# We support loading from both `*ForCausalLM` and `*Model`
candidate_prefixes = ["", "model."]
target_prefix = ""
seen_weights = list[tuple[str, torch.Tensor]]()
for name, loaded_weight in weights:
seen_weights.append((name, loaded_weight))
try:
target_prefix = next(
prefix
for prefix in candidate_prefixes
if prefix + name in params_dict
)
break
except StopIteration:
# The weight might not exist on the model
# (to be handled by AutoWeightsLoader)
pass
if target_prefix:
target_model = self
for attr in target_prefix.split("."):
if attr:
target_model = getattr(self, attr)
logger.info(
"Mapping weights to %s as they are "
"relative to this model instead of %s.",
target_model._get_name(),
self._get_name(),
)
# If `*ForCausalLM` defines `load_weights` on the inner model
# and there are no other inner modules with parameters,
# we support loading from both `*Model` and `*ForCausalLM`
if hasattr(self, "model") and hasattr(self.model, "load_weights"):
# Whether only `self.model` contains parameters
model_is_only_param = all(
name == "model" or next(child.parameters(), None) is None
for name, child in self.named_children()
)
mapped_weights = (
(target_prefix + name, weight)
for name, weight in (*seen_weights, *weights)
)
if model_is_only_param:
mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})
weights = mapper.apply(weights)
loaded_params = self.model.load_weights(weights)
loaded_params = {f"model.{name}" for name in loaded_params}
return loaded_params
# For most other models
if hasattr(orig_cls, "load_weights"):
return orig_cls.load_weights(self, weights) # type: ignore
# Fallback
else:
def default_load_weights(weights):
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)
load_weights = getattr(super(), "load_weights", default_load_weights)
return load_weights(mapped_weights)
return ModelForPooling # type: ignore
@@ -255,11 +241,15 @@ def as_embedding_model(cls: _T) -> _T:
from vllm.model_executor.layers.pooler import DispatchPooler
class ModelForEmbedding(_create_pooling_model_cls(cls)):
def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""):
def _init_pooler(
self,
vllm_config: "VllmConfig",
prefix: str = "",
) -> "Pooler":
pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None
self.pooler = DispatchPooler.for_embedding(pooler_config)
return DispatchPooler.for_embedding(pooler_config)
ModelForEmbedding.__name__ = _get_pooling_model_name(cls.__name__, "ForEmbedding")
@@ -292,7 +282,11 @@ def as_seq_cls_model(cls: _T) -> _T:
class ModelForSequenceClassification(
_create_pooling_model_cls(cls), SupportsCrossEncoding
):
def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""):
def _init_pooler(
self,
vllm_config: "VllmConfig",
prefix: str = "",
) -> "Pooler":
text_config = vllm_config.model_config.hf_config.get_text_config()
model_config = vllm_config.model_config
quant_config = vllm_config.quant_config
@@ -310,9 +304,7 @@ def as_seq_cls_model(cls: _T) -> _T:
pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None
self.pooler = DispatchPooler.for_seq_cls(
pooler_config, classifier=self.score
)
return DispatchPooler.for_seq_cls(pooler_config, classifier=self.score)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
hf_config = self.config
@@ -424,7 +416,7 @@ def load_weights_using_from_2_way_softmax(
pooling_model_cls = next(
x for x in type(model).__mro__ if x.__name__ == "ModelForPooling"
)
loaded_weights = pooling_model_cls.load_weights(model, weights, load_lm_head=True)
loaded_weights = pooling_model_cls.load_weights(model, weights)
from vllm.tokenizers import get_tokenizer

View File

@@ -44,11 +44,11 @@ from .interfaces import (
SupportsLoRA,
SupportsMultiModal,
SupportsPP,
TowerMissingLayer,
)
from .siglip import SiglipVisionModel
from .utils import (
AutoWeightsLoader,
StageMissingLayer,
WeightsMapper,
init_vllm_registered_model,
maybe_prefix,
@@ -426,9 +426,9 @@ class BagelForConditionalGeneration(
hidden_size=llm_hidden_size,
)
else:
self.vit_model = TowerMissingLayer("image")
self.connector = TowerMissingLayer("image")
self.vit_pos_embed = TowerMissingLayer("image")
self.vit_model = StageMissingLayer("image_tower")
self.connector = StageMissingLayer("image_tower")
self.vit_pos_embed = StageMissingLayer("image_tower")
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors

View File

@@ -935,9 +935,20 @@ class ChameleonForConditionalGeneration(
multimodal_config = vllm_config.model_config.multimodal_config
self.config = config
self.multimodal_config = multimodal_config
self.model = ChameleonModel(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
)
with self._mark_composite_model(
vllm_config,
language_targets=(
ChameleonDecoderLayer
if not self.config.swin_norm
else ChameleonSwinDecoderLayer
),
tower_targets={"image": ChameleonVQVAE},
):
self.model = ChameleonModel(
vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"),
)
self.lm_head = ParallelLMHead(
config.vocab_size,
@@ -970,9 +981,6 @@ class ChameleonForConditionalGeneration(
resolve_bindings={"h": expected_h, "w": expected_w},
)
def get_language_model(self) -> torch.nn.Module:
return self.model
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:

View File

@@ -539,10 +539,7 @@ class Gemma3ForConditionalGeneration(
)
logit_scale = getattr(config, "logit_scale", 1.0)
if hasattr(self.language_model, "logits_processor"):
# The logits processor can be unset if we're using
# automatic conversion to pooling model.
self.language_model.logits_processor.scale *= logit_scale
self.language_model.logits_processor.scale *= logit_scale
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors

View File

@@ -53,7 +53,7 @@ from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs import ChatGLMConfig
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .chatglm import ChatGLMBaseModel, ChatGLMModel
from .chatglm import ChatGLMBaseModel, ChatGLMModel, GLMTransformer
from .interfaces import (
MultiModalEmbeddings,
SupportsLoRA,
@@ -591,11 +591,16 @@ class GLM4VForCausalLM(
prefix: str = "",
transformer_type: type[GLM4VModel] = GLM4VModel,
) -> None:
super().__init__(
vllm_config=vllm_config,
prefix=prefix,
transformer_type=transformer_type,
)
with self._mark_composite_model(
vllm_config,
language_targets=GLMTransformer,
tower_targets={"image": EVA2CLIPModel},
):
super().__init__(
vllm_config=vllm_config,
prefix=prefix,
transformer_type=transformer_type,
)
self.transformer: GLM4VModel
@@ -752,9 +757,6 @@ class GLM4VForCausalLM(
mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
return llm_positions, mrope_position_delta
def get_language_model(self) -> torch.nn.Module:
return self.transformer
embed_input_ids = SupportsMultiModal.embed_input_ids
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:

View File

@@ -57,7 +57,11 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .idefics2_vision_model import (
Idefics2VisionTransformer as Idefics3VisionTransformer,
)
from .interfaces import MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal
from .interfaces import (
MultiModalEmbeddings,
SupportsLoRA,
SupportsMultiModal,
)
from .llama import LlamaModel
from .utils import AutoWeightsLoader, maybe_prefix
@@ -604,9 +608,16 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLo
self.config = config
self.multimodal_config = multimodal_config
self.model = Idefics3Model(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
)
with self._mark_composite_model(
vllm_config,
language_targets=LlamaModel,
tower_targets={"image": (Idefics3VisionTransformer, Idefics3Connector)},
):
self.model = Idefics3Model(
vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"),
)
self.image_token_id = self.config.image_token_id
self.lm_head = ParallelLMHead(
@@ -669,9 +680,6 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLo
num_patches = image_input["num_patches"]
return [e.flatten(0, 1) for e in image_features.split(num_patches.tolist())]
def get_language_model(self) -> torch.nn.Module:
return self.model
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:

View File

@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable, Iterable, Mapping, MutableSequence
from contextlib import contextmanager, nullcontext
from contextlib import ExitStack, contextmanager, nullcontext
from typing import (
TYPE_CHECKING,
ClassVar,
@@ -25,6 +25,7 @@ from vllm.inputs import TokensPrompt
from vllm.inputs.data import PromptType
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.utils.collection_utils import common_prefix
from vllm.utils.func_utils import supports_kw
from .interfaces_base import VllmModel, is_pooling_model
@@ -70,46 +71,8 @@ def _require_is_multimodal(is_multimodal: Tensor | None) -> Tensor:
return is_multimodal
class LMMissingLayer(nn.Module):
def make_empty_intermediate_tensors(self, *args, **kwargs):
raise RuntimeError("This module should not be called in MM encoder-only mode")
def __call__(self, *args, **kwargs):
raise RuntimeError("This module should not be called in MM encoder-only mode")
class TowerMissingLayer(nn.Module):
def __init__(self, modalities: set[str] | str) -> None:
if isinstance(modalities, str):
modalities = {modalities}
super().__init__()
self.modalities = modalities
def __call__(self, *args, **kwargs):
raise RuntimeError(
f"This module should not be called when the following "
f"modalities are disabled: {self.modalities}"
)
@contextmanager
def _no_init_weights(module: nn.Module, placeholder: Callable[[], nn.Module]):
"""
Within this context, prevent weight initialization from using device memory and
replace direct child assignments to `module` with the result of `placeholder()`.
"""
def callback(module_, name, submodule):
if module_ is module:
return placeholder()
return submodule
with torch.nn.modules.module.register_module_module_registration_hook(callback): # noqa: E501,SIM117
with torch.device("meta"):
yield
# Cache results of `SupportsMultiModal.get_language_model`
_language_model_by_module = dict[nn.Module, VllmModel]()
@runtime_checkable
@@ -187,31 +150,61 @@ class SupportsMultiModal(Protocol):
Returns:
torch.nn.Module: The core language model component.
"""
# Cached
if self in _language_model_by_module:
return _language_model_by_module[self]
if self._language_model_names:
return getattr(self, self._language_model_names[0])
mod = self
for attr in common_prefix(
[name.split(".") for name in self._language_model_names]
):
if attr:
mod = getattr(mod, attr)
if mod is not self and hasattr(mod, "embed_input_ids"):
_language_model_by_module[self] = mod
return mod
# Fallback
for mod in self.children():
if hasattr(mod, "embed_input_ids"):
_language_model_by_module[self] = mod
return mod
raise NotImplementedError(
f"No language model found in {type(self).__name__}! "
"You should initialize it inside `_mark_language_model`."
"You should initialize it via `_mark_language_model`."
)
@contextmanager
def _mark_language_model(self, vllm_config: VllmConfig):
def _mark_language_model(
self,
vllm_config: VllmConfig,
*,
targets: type[nn.Module] | tuple[type[nn.Module], ...] | None = None,
):
"""
Mark each child module that was assigned to this model
during this context as a language model component.
Mark each child module that was assigned to this model during this context
as a language model component.
Language model components are automatically skipped in `--mm-encoder-only`
mode.
If `targets` is set, instead include descendants that are an instance
of `targets`, even if they aren't direct children.
"""
from .utils import StageMissingLayer, collect_children, no_init_weights
mm_config = vllm_config.model_config.multimodal_config
children_names = list[str]()
def callback(module_, name, submodule):
if module_ is self:
children_names.append(name)
with torch.nn.modules.module.register_module_module_registration_hook(callback): # noqa: E501,SIM117
with collect_children(self, targets=targets) as children_names: # noqa: SIM117
with (
_no_init_weights(self, LMMissingLayer)
no_init_weights(
self,
lambda mod: StageMissingLayer("language_model", mod),
targets=targets,
)
if mm_config.mm_encoder_only
else nullcontext()
):
@@ -220,25 +213,42 @@ class SupportsMultiModal(Protocol):
self._language_model_names = children_names
@contextmanager
def _mark_tower_model(self, vllm_config: VllmConfig, modalities: set[str] | str):
def _mark_tower_model(
self,
vllm_config: VllmConfig,
modalities: set[str] | str,
*,
targets: type[nn.Module] | tuple[type[nn.Module], ...] | None = None,
):
"""
Mark each child module that was assigned to this model
during this context as a tower model component.
Mark each child module that was assigned to this model during this context
as a tower model component.
Tower model components are automatically skipped when `--limit-mm-per-prompt`
is set to zero for all of their modalities.
If `targets` is set, instead include descendants that are an instance
of `targets`, even if they aren't direct children.
"""
from .utils import StageMissingLayer, collect_children, no_init_weights
if isinstance(modalities, str):
modalities = {modalities}
if modalities == {"image", "video"}:
stage_name = "vision_tower"
else:
stage_name = "_".join([*modalities, "tower"])
mm_config = vllm_config.model_config.multimodal_config
children_names = list[str]()
def callback(module_, name, submodule):
if module_ is self:
children_names.append(name)
with torch.nn.modules.module.register_module_module_registration_hook(callback): # noqa: E501,SIM117
with collect_children(self, targets=targets) as children_names: # noqa: SIM117
with (
_no_init_weights(self, lambda: TowerMissingLayer(modalities))
no_init_weights(
self,
lambda mod: StageMissingLayer(stage_name, mod),
targets=targets,
)
if all(mm_config.get_limit_per_prompt(m) == 0 for m in modalities)
else nullcontext()
):
@@ -246,6 +256,37 @@ class SupportsMultiModal(Protocol):
self._tower_model_names = children_names
@contextmanager
def _mark_composite_model(
    self,
    vllm_config: VllmConfig,
    *,
    language_targets: type[nn.Module] | tuple[type[nn.Module], ...],
    tower_targets: dict[str, type[nn.Module] | tuple[type[nn.Module], ...]],
):
    """
    Composite wrapper over `_mark_language_model` and
    `_mark_tower_model` by modality.

    Args:
        vllm_config: Forwarded unchanged to both marker contexts.
        language_targets: Module class(es) making up the language backbone;
            passed as `targets` to `_mark_language_model`.
        tower_targets: Mapping from modality name (e.g. `"image"`) to the
            module class(es) of that modality's tower; each entry opens one
            `_mark_tower_model` context.
    """
    # ExitStack lets us hold open a variable number of nested contexts:
    # one for the language model plus one per tower modality.
    with ExitStack() as stack:
        stack.enter_context(
            self._mark_language_model(
                vllm_config,
                targets=language_targets,
            )
        )
        for modality, modality_targets in tower_targets.items():
            stack.enter_context(
                self._mark_tower_model(
                    vllm_config,
                    modality,
                    targets=modality_targets,
                )
            )
        # All marker contexts stay active while the caller's body runs
        yield
def get_num_mm_encoder_tokens(self, num_image_tokens: int) -> int:
"""
Implement this function to enable LoRA support

View File

@@ -41,10 +41,12 @@ from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP
from .interfaces_base import default_pooling_type
from .utils import (
StageMissingLayer,
is_pp_missing_parameter,
make_empty_intermediate_tensors_factory,
make_layers,
maybe_prefix,
no_init_weights,
)
@@ -413,10 +415,16 @@ class InternLM2ForRewardModel(InternLM2ForCausalLM):
prefix: str = "",
model_type: type[InternLM2Model] = InternLM2Model,
):
super().__init__(vllm_config=vllm_config, prefix=prefix, model_type=model_type)
for attr in ("output", "logits_processor"):
delattr(self, attr)
with no_init_weights(
self,
lambda mod: StageMissingLayer("output", mod),
targets=(LogitsProcessor, ParallelLMHead),
):
super().__init__(
vllm_config=vllm_config,
prefix=prefix,
model_type=model_type,
)
config = vllm_config.model_config.hf_config
self.head_dtype = vllm_config.model_config.head_dtype

View File

@@ -1035,11 +1035,13 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
)
with self._mark_tower_model(vllm_config, {"image", "video"}):
self.vpm = vpm = self.init_vision_module(
self.vpm = self.init_vision_module(
config, quant_config, prefix=maybe_prefix(prefix, "vpm")
)
self.vision_dim = (
vpm.embed_dim if self.version == (2, 0) else vpm.embeddings.embed_dim
self.vpm.embed_dim
if self.version == (2, 0)
else self.vpm.embeddings.embed_dim
)
self.embed_dim = self.config.hidden_size

View File

@@ -70,20 +70,15 @@ from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import (
LMMissingLayer,
MixtureOfExperts,
MultiModalEmbeddings,
SupportsEagle3,
SupportsLoRA,
SupportsMultiModal,
SupportsPP,
TowerMissingLayer,
)
from .llama4 import Llama4ForCausalLM
from .utils import (
AutoWeightsLoader,
maybe_prefix,
)
from .utils import AutoWeightsLoader, StageMissingLayer, maybe_prefix
from .vision import run_dp_sharded_vision_model
@@ -1024,7 +1019,7 @@ class Llama4ForConditionalGeneration(
renamed = self._rename_weight_for_modelopt_checkpoint(name)
attr = renamed.split(".", 1)[0]
if isinstance(getattr(self, attr), (LMMissingLayer, TowerMissingLayer)):
if isinstance(getattr(self, attr), StageMissingLayer):
continue
if renamed.startswith("language_model."):

View File

@@ -1513,7 +1513,7 @@ class NemotronH_Nano_VL_V2(
self.video_pruning_rate = multimodal_config.video_pruning_rate
with self._mark_language_model(vllm_config):
self.language_model = language_model = init_vllm_registered_model(
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"),
@@ -1542,7 +1542,7 @@ class NemotronH_Nano_VL_V2(
ReLUSquaredActivation(),
nn.Linear(vision_projection_hidden_size, llm_hidden_size, bias=False),
)
self.mlp1 = mlp1.to(language_model.config.dtype)
self.mlp1 = mlp1.to(self.language_model.config.dtype)
self.config = config
self.model_config = vllm_config.model_config

View File

@@ -1025,12 +1025,12 @@ class PaddleOCRVLForConditionalGeneration(nn.Module, SupportsMultiModal, Support
self.mlp_AR = Projector(config, config.vision_config)
with self._mark_language_model(vllm_config):
self.language_model = language_model = Ernie4_5ForCausalLM(
self.language_model = Ernie4_5ForCausalLM(
vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "language_model"),
)
for layer in language_model.model.layers:
for layer in self.language_model.model.layers:
if not isinstance(layer, PPMissingLayer):
layer.self_attn.rotary_emb.is_neox_style = True

View File

@@ -314,13 +314,14 @@ class PaliGemmaForConditionalGeneration(
config.text_config.architectures = ["Gemma2ForCausalLM"]
with self._mark_language_model(vllm_config):
self.language_model = language_model = init_vllm_registered_model(
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"),
)
logit_scale = getattr(config, "logit_scale", 1.0)
language_model.logits_processor.scale *= logit_scale
self.language_model.logits_processor.scale *= logit_scale
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors

View File

@@ -461,15 +461,16 @@ class Qwen3VLMoeForConditionalGeneration(
]
with self._mark_language_model(vllm_config):
self.language_model = language_model = Qwen3MoeLLMForCausalLM(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "language_model")
self.language_model = Qwen3MoeLLMForCausalLM(
vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "language_model"),
)
# Whether to include the gate_up_proj mapping is determined by
# the language model.
self.packed_modules_mapping = (
self.packed_modules_mapping | language_model.packed_modules_mapping
)
# Whether to include the gate_up_proj mapping is determined by
# the language model.
self.packed_modules_mapping = (
self.packed_modules_mapping | self.language_model.packed_modules_mapping
)
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors

View File

@@ -58,7 +58,7 @@ from .interfaces import (
SupportsMultiModal,
SupportsPP,
)
from .qwen import QWenBaseModel, QWenModel
from .qwen import QWenBaseModel, QWenBlock, QWenModel
class QwenImagePixelInputs(TensorSchema):
@@ -757,11 +757,16 @@ class QwenVLForConditionalGeneration(
prefix: str = "",
transformer_type: type[QwenVLModel] = QwenVLModel,
) -> None:
super().__init__(
vllm_config=vllm_config,
prefix=prefix,
transformer_type=transformer_type,
)
with self._mark_composite_model(
vllm_config,
language_targets=QWenBlock,
tower_targets={"image": VisionTransformer},
):
super().__init__(
vllm_config=vllm_config,
prefix=prefix,
transformer_type=transformer_type,
)
self.transformer: QwenVLModel
@@ -795,9 +800,6 @@ class QwenVLForConditionalGeneration(
return self.transformer.visual(image_input["data"])
def get_language_model(self) -> torch.nn.Module:
return self.transformer
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:

View File

@@ -2,13 +2,15 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import itertools
from collections.abc import Iterable, Mapping
from collections.abc import Callable, Iterable, Mapping
from contextlib import contextmanager
from dataclasses import dataclass, field
from typing import Any, Literal, Protocol, overload
import torch
import torch.nn as nn
from torch.func import functional_call
from torch.nn.modules.module import register_module_module_registration_hook
from transformers import PretrainedConfig
from vllm.config import VllmConfig
@@ -24,11 +26,7 @@ from vllm.model_executor.model_loader.online_quantization import (
support_quantized_model_reload_from_hp_weights,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import (
LMMissingLayer,
TowerMissingLayer,
supports_any_eagle,
)
from vllm.model_executor.models.interfaces import supports_any_eagle
from vllm.multimodal import NestedTensors
from vllm.sequence import IntermediateTensors
from vllm.utils.math_utils import cdiv
@@ -214,8 +212,8 @@ class AutoWeightsLoader:
continue
raise ValueError(
f"Attempted to load nested weight '{weight_qualname}' "
f"into a single parameter '{base_prefix}'"
f"Attempted to load nested weight {weight_qualname!r} "
f"into a single parameter {base_prefix!r}"
)
weight_loader = getattr(param, "weight_loader", default_weight_loader)
@@ -254,7 +252,7 @@ class AutoWeightsLoader:
module: nn.Module,
weights: Iterable[tuple[str, torch.Tensor]],
) -> Iterable[str]:
if isinstance(module, (LMMissingLayer, TowerMissingLayer, PPMissingLayer)):
if isinstance(module, (StageMissingLayer, PPMissingLayer)):
return
# Avoid infinite recursion since this function is typically
@@ -316,9 +314,14 @@ class AutoWeightsLoader:
continue
desc_param_keys = {
base_prefix + k for k, _ in module.named_parameters(recurse=True)
}
msg = (
f"There is no module or parameter named '{prefix}' "
f"in {type(self.module).__name__}"
f"There is no module or parameter named {prefix!r} "
f"in {self.module._get_name()}. "
f"The available parameters belonging to {base_prefix} "
f"({module._get_name()}) are: {desc_param_keys}"
)
raise ValueError(msg)
@@ -496,6 +499,100 @@ def isin_list(
return torch.isin(elements, test_elements)
class StageMissingLayer(nn.Module):
    """
    Placeholder that stands in for a module whose stage (e.g. the language
    model or a modality tower) is disabled in the current configuration.

    Calling the placeholder raises; attribute access is forwarded to the
    original `module` (if one was provided) so that read-only lookups on the
    replaced module still work.
    """

    def __init__(self, stage_name: str, module: nn.Module | None = None) -> None:
        super().__init__()
        # Name of the disabled stage, e.g. "output" or "image_tower"
        self.stage_name = stage_name
        # Don't register this as a child module in order to
        # avoid missing keys when loading weights
        self.__dict__["module"] = module

    def __getattr__(self, name: str):
        # Only reached when `name` is not found on this instance itself;
        # forward the lookup to the wrapped (unregistered) module.
        # NOTE(review): raises AttributeError on NoneType if no module
        # was wrapped — presumably acceptable; confirm with callers.
        return getattr(self.__dict__["module"], name)

    def __call__(self, *args, **kwargs):
        raise RuntimeError(f"{self} should not be called")

    def extra_repr(self) -> str:
        # Shown in repr() so the disabled stage is visible when printing models
        return f"stage_name={self.stage_name!r}"
@contextmanager
def collect_children(
module: nn.Module,
*,
targets: type[nn.Module] | tuple[type[nn.Module], ...] | None = None,
):
"""
Within this context, collect all direct child assignments to `module`,
returning a list of children names that is internally updated until the
context is exited.
If `targets` is set, instead collect descendents of `module`
that are an instance of `targets`, even if they aren't direct children.
"""
children_names = list[str]()
if targets is None:
def hook(module_: nn.Module, name: str, submodule: nn.Module):
if module_ is module:
children_names.append(name)
with register_module_module_registration_hook(hook):
yield children_names
else:
yield children_names
for name, module_ in module.named_modules():
if isinstance(module_, targets):
children_names.append(name)
@contextmanager
def no_init_weights(
module: nn.Module,
placeholder: Callable[[nn.Module], nn.Module],
*,
targets: type[nn.Module] | tuple[type[nn.Module], ...] | None = None,
):
"""
Within this context, prevent weight initialization from using device memory and
replace direct child assignments to `module` with the result of `placeholder()`.
If `targets` is set, instead prevent weight initialization and
replace assignments where the child is an instance of `targets`,
even if they aren't direct children of `module`.
"""
if targets is None:
def hook(module_: nn.Module, name: str, submodule: nn.Module):
if module_ is module:
return placeholder(submodule)
return submodule
with register_module_module_registration_hook(hook), torch.device("meta"):
yield
else:
def hook(module_: nn.Module, name: str, submodule: nn.Module):
if isinstance(module_, targets):
submodule.to("meta") # Free memory
if isinstance(submodule, targets):
submodule.to("meta") # Free memory
return placeholder(submodule)
return submodule
# Not all descendents are targeted, so we can't use a blanket
# `torch.device("meta")` context
with register_module_module_registration_hook(hook):
yield
class LayerFn(Protocol):
    """Protocol for callables that construct a module from its parameter-name `prefix`."""

    def __call__(self, prefix: str) -> torch.nn.Module: ...
@@ -627,7 +724,7 @@ def get_pp_missing_layer_names(model: torch.nn.Module) -> list[str]:
missing_layer_names = []
for name, module in model.named_modules():
if isinstance(module, PPMissingLayer):
if isinstance(module, (StageMissingLayer, PPMissingLayer)):
# NOTE: the trailing dot is used to match the prefix of the layer.
# without the dot, we could match a layer that is not missing,
# e.g., 'encoder.layer.1' would match 'encoder.layer.11'
@@ -639,7 +736,7 @@ def get_pp_missing_layer_names(model: torch.nn.Module) -> list[str]:
def is_pp_missing_parameter(name: str, model: torch.nn.Module) -> bool:
"""Check if a parameter is missing in a pipeline parallel model."""
if isinstance(model, PPMissingLayer):
if isinstance(model, (StageMissingLayer, PPMissingLayer)):
return True
return any(

View File

@@ -909,7 +909,12 @@ class WhisperForConditionalGeneration(
self.config = config
self.dtype = vllm_config.model_config.dtype
self.model = WhisperModel(vllm_config=vllm_config, prefix=prefix)
with self._mark_composite_model(
vllm_config,
language_targets=WhisperDecoder,
tower_targets={"audio": WhisperEncoder},
):
self.model = WhisperModel(vllm_config=vllm_config, prefix=prefix)
self.proj_out = ParallelLMHead(
config.vocab_size,
@@ -937,9 +942,6 @@ class WhisperForConditionalGeneration(
)
return decoder_outputs
def get_language_model(self) -> torch.nn.Module:
return self.model.decoder
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
# Required as part of SupportsMultiModal interface.
audio_input = self._parse_and_validate_audio_input(**kwargs)

View File

@@ -7,10 +7,10 @@ This is similar in concept to the `collections` module.
"""
from collections import defaultdict
from collections.abc import Callable, Generator, Hashable, Iterable, Mapping
from collections.abc import Callable, Generator, Hashable, Iterable, Mapping, Sequence
from typing import Generic, Literal, TypeVar
from typing_extensions import TypeIs, assert_never
from typing_extensions import TypeIs, assert_never, overload
T = TypeVar("T")
@@ -74,6 +74,34 @@ def is_list_of(
assert_never(check)
@overload
def common_prefix(items: Sequence[str]) -> str: ...
@overload
def common_prefix(items: Sequence[Sequence[T]]) -> Sequence[T]: ...
def common_prefix(items: Sequence[Sequence[T] | str]) -> Sequence[T] | str:
"""Find the longest prefix common to all items."""
if len(items) == 0:
return []
if len(items) == 1:
return items[0]
shortest = min(items, key=len)
if not shortest:
return shortest[:0]
for match_len in range(1, len(shortest) + 1):
match = shortest[:match_len]
for item in items:
if item[:match_len] != match:
return shortest[: match_len - 1]
return shortest
def chunk_list(lst: list[T], chunk_size: int) -> Generator[list[T]]:
"""Yield successive chunk_size chunks from lst."""
for i in range(0, len(lst), chunk_size):