[Model] Extend collect_children and no_init_weights contexts (#32757)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Cyrus Leung authored 2026-01-22 16:20:27 +08:00, committed by GitHub
parent 1bf1a34b19
commit 2b8a38b6d6
20 changed files with 444 additions and 257 deletions


@@ -1,8 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import ast
import inspect
from collections.abc import Iterable
from typing import TYPE_CHECKING, Any, TypeVar, cast
@@ -18,10 +16,12 @@ from vllm.transformers_utils.config import (
)
from vllm.transformers_utils.repo_utils import get_hf_file_bytes
from .interfaces import supports_multimodal
from .interfaces_base import VllmModelForPooling, is_pooling_model
if TYPE_CHECKING:
from vllm.config import ModelConfig, VllmConfig
from vllm.model_executor.layers.pooler import Pooler
_T = TypeVar("_T", bound=type[nn.Module])
@@ -124,41 +124,12 @@ def _get_pooling_model_name(orig_model_name: str, pooling_suffix: str) -> str:
return model_name + pooling_suffix
def try_create_mm_pooling_model_cls(orig_cls: _T) -> _T:
class CallVisitor(ast.NodeVisitor):
def __init__(self):
self.calls = []
def visit_Call(self, node):
if isinstance(node.func, ast.Name):
self.calls.append(node.func.id)
self.generic_visit(node)
visitor = CallVisitor()
visitor.visit(ast.parse(inspect.getsource(orig_cls)))
if "init_vllm_registered_model" not in visitor.calls:
return None
class ModelForPooling(orig_cls, VllmModelForPooling):
is_pooling_model = True
def __init__(
self,
*,
vllm_config: "VllmConfig",
prefix: str = "",
**kwargs: Any,
) -> None:
super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)
self.pooler = self.get_language_model().pooler
return ModelForPooling # type: ignore
def _create_pooling_model_cls(orig_cls: _T) -> _T:
# Lazy import
from .utils import AutoWeightsLoader, WeightsMapper
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from .utils import AutoWeightsLoader, StageMissingLayer, no_init_weights
class ModelForPooling(orig_cls, VllmModelForPooling):
is_pooling_model = True
@@ -170,69 +141,84 @@ def _create_pooling_model_cls(orig_cls: _T) -> _T:
prefix: str = "",
**kwargs: Any,
) -> None:
super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)
with no_init_weights(
self,
lambda mod: StageMissingLayer("output", mod),
targets=(LogitsProcessor, ParallelLMHead),
):
super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)
# Used by SEQ_CLS_LOAD_METHODS
self.vllm_config = vllm_config
# These are not used in pooling models
objects_to_clean = [self]
if language_model := getattr(self, "language_model", None):
objects_to_clean.append(language_model)
for obj in objects_to_clean:
for attr in ("lm_head", "logits_processor"):
if hasattr(obj, attr):
delattr(obj, attr)
# If the model already defines a pooler instance, don't overwrite it
if not getattr(self, "pooler", None):
self._init_pooler(vllm_config, prefix=prefix)
pooler = getattr(self, "pooler", None)
if not pooler and supports_multimodal(self):
# Try to get the pooler from the LM backbone
language_model = self.get_language_model()
if hasattr(language_model, "pooler"):
pooler = language_model.pooler
def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""):
if not pooler:
pooler = self._init_pooler(vllm_config, prefix=prefix)
self.pooler = pooler
def _init_pooler(
self,
vllm_config: "VllmConfig",
prefix: str = "",
) -> "Pooler":
raise NotImplementedError
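
In the constructor above, super().__init__() now runs inside a no_init_weights(...) context so that the generation-only heads (LogitsProcessor and ParallelLMHead) are replaced by StageMissingLayer placeholders rather than being built and later deleted. As a rough, self-contained illustration of that pattern (simplified names and mechanism; not vLLM's actual implementation), such a context manager could intercept submodule attachment and swap in a placeholder:

# Illustrative sketch only: swap targeted submodule types for a placeholder
# while a module tree is being constructed, so their weights are never allocated.
import contextlib
import torch.nn as nn

class MissingLayerPlaceholder(nn.Module):
    """Records which stage a skipped module belonged to."""

    def __init__(self, stage: str, original: nn.Module) -> None:
        super().__init__()
        self.stage = stage
        self.original_type = type(original).__name__

    def forward(self, *args, **kwargs):
        raise RuntimeError(
            f"{self.original_type} was skipped for stage {self.stage!r}"
        )

@contextlib.contextmanager
def skip_submodule_init(targets: tuple[type[nn.Module], ...], stage: str = "output"):
    orig_setattr = nn.Module.__setattr__

    def patched_setattr(module, name, value):
        # Replace any targeted submodule with the placeholder at attach time.
        if isinstance(value, targets):
            value = MissingLayerPlaceholder(stage, value)
        orig_setattr(module, name, value)

    nn.Module.__setattr__ = patched_setattr
    try:
        yield
    finally:
        nn.Module.__setattr__ = orig_setattr

The context used in the diff additionally receives the wrapped module and a factory callable (lambda mod: StageMissingLayer("output", mod)), but the intent for pooling models is the same: output-stage weights are never materialized or loaded.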
def load_weights(
self,
weights: Iterable[tuple[str, torch.Tensor]],
load_lm_head: bool = False,
):
# TODO: Support uninitialized params tracking
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
params_dict = dict(self.named_parameters())
# For most pooling models: We have deleted this attribute, so don't load it.
# For converting an LLM into a seq cls model, we need the lm_head.
if not load_lm_head:
weights = (
(name, data)
for name, data in weights
if not name.startswith("lm_head.")
# We support loading from both `*ForCausalLM` and `*Model`
candidate_prefixes = ["", "model."]
target_prefix = ""
seen_weights = list[tuple[str, torch.Tensor]]()
for name, loaded_weight in weights:
seen_weights.append((name, loaded_weight))
try:
target_prefix = next(
prefix
for prefix in candidate_prefixes
if prefix + name in params_dict
)
break
except StopIteration:
# The weight might not exist on the model
# (to be handled by AutoWeightsLoader)
pass
if target_prefix:
target_model = self
for attr in target_prefix.split("."):
if attr:
target_model = getattr(self, attr)
logger.info(
"Mapping weights to %s as they are "
"relative to this model instead of %s.",
target_model._get_name(),
self._get_name(),
)
# If `*ForCausalLM` defines `load_weights` on the inner model
# and there are no other inner modules with parameters,
# we support loading from both `*Model` and `*ForCausalLM`
if hasattr(self, "model") and hasattr(self.model, "load_weights"):
# Whether only `self.model` contains parameters
model_is_only_param = all(
name == "model" or next(child.parameters(), None) is None
for name, child in self.named_children()
)
mapped_weights = (
(target_prefix + name, weight)
for name, weight in (*seen_weights, *weights)
)
if model_is_only_param:
mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})
weights = mapper.apply(weights)
loaded_params = self.model.load_weights(weights)
loaded_params = {f"model.{name}" for name in loaded_params}
return loaded_params
# For most other models
if hasattr(orig_cls, "load_weights"):
return orig_cls.load_weights(self, weights) # type: ignore
# Fallback
else:
def default_load_weights(weights):
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)
load_weights = getattr(super(), "load_weights", default_load_weights)
return load_weights(mapped_weights)
return ModelForPooling # type: ignore
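
The rewritten load_weights buffers incoming checkpoint entries until one of them matches a known parameter name, either directly or after prepending "model.", then replays everything with the detected prefix applied before delegating to the parent class loader (or an AutoWeightsLoader fallback). The detection step, pulled out as a stand-alone helper for illustration (hypothetical function, condensed from the code above):

# Hypothetical helper mirroring the prefix detection in load_weights above:
# buffer weights until one matches a known parameter name, possibly with a
# "model." prefix, then re-emit everything with that prefix applied.
from collections.abc import Iterable, Iterator

import torch

def remap_checkpoint_prefix(
    params_dict: dict[str, torch.Tensor],
    weights: Iterable[tuple[str, torch.Tensor]],
) -> Iterator[tuple[str, torch.Tensor]]:
    candidate_prefixes = ["", "model."]
    target_prefix = ""
    seen: list[tuple[str, torch.Tensor]] = []
    weights = iter(weights)
    for name, tensor in weights:
        seen.append((name, tensor))
        match = next(
            (p for p in candidate_prefixes if p + name in params_dict), None
        )
        if match is not None:
            target_prefix = match
            break
    # Replay the buffered weights followed by the rest of the iterator,
    # with the detected prefix applied to every name.
    for name, tensor in (*seen, *weights):
        yield target_prefix + name, tensor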
@@ -255,11 +241,15 @@ def as_embedding_model(cls: _T) -> _T:
from vllm.model_executor.layers.pooler import DispatchPooler
class ModelForEmbedding(_create_pooling_model_cls(cls)):
def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""):
def _init_pooler(
self,
vllm_config: "VllmConfig",
prefix: str = "",
) -> "Pooler":
pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None
self.pooler = DispatchPooler.for_embedding(pooler_config)
return DispatchPooler.for_embedding(pooler_config)
ModelForEmbedding.__name__ = _get_pooling_model_name(cls.__name__, "ForEmbedding")
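
For reference, as_embedding_model is the public wrapper around _create_pooling_model_cls; a minimal usage sketch (the wrapped model class and import paths are assumptions for illustration):

# Illustrative usage; the chosen model class is an assumption for the example.
from vllm.model_executor.models.adapters import as_embedding_model
from vllm.model_executor.models.llama import LlamaForCausalLM

# Returns a subclass with is_pooling_model = True whose _init_pooler builds a
# DispatchPooler.for_embedding(...) from the model's pooler_config.
LlamaEmbeddingModel = as_embedding_model(LlamaForCausalLM)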
@@ -292,7 +282,11 @@ def as_seq_cls_model(cls: _T) -> _T:
class ModelForSequenceClassification(
_create_pooling_model_cls(cls), SupportsCrossEncoding
):
def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""):
def _init_pooler(
self,
vllm_config: "VllmConfig",
prefix: str = "",
) -> "Pooler":
text_config = vllm_config.model_config.hf_config.get_text_config()
model_config = vllm_config.model_config
quant_config = vllm_config.quant_config
@@ -310,9 +304,7 @@ def as_seq_cls_model(cls: _T) -> _T:
pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None
self.pooler = DispatchPooler.for_seq_cls(
pooler_config, classifier=self.score
)
return DispatchPooler.for_seq_cls(pooler_config, classifier=self.score)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
hf_config = self.config
@@ -424,7 +416,7 @@ def load_weights_using_from_2_way_softmax(
pooling_model_cls = next(
x for x in type(model).__mro__ if x.__name__ == "ModelForPooling"
)
loaded_weights = pooling_model_cls.load_weights(model, weights, load_lm_head=True)
loaded_weights = pooling_model_cls.load_weights(model, weights)
from vllm.tokenizers import get_tokenizer
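
load_weights_using_from_2_way_softmax (only partially visible in this hunk) relies on the standard reduction from a two-token softmax head to a single-logit classifier: softmax over two logits depends only on their difference, so the classifier weight can be taken as the difference of the corresponding lm_head rows. A hedged sketch of that reduction (tensor and argument names are hypothetical):

# softmax([z_pos, z_neg]) is a function of (z_pos - z_neg) only, so a single
# linear score with weight (w_pos - w_neg) reproduces P(positive) via sigmoid.
import torch

def two_way_softmax_to_classifier(
    lm_head_weight: torch.Tensor,  # shape: (vocab_size, hidden_size)
    pos_token_id: int,
    neg_token_id: int,
) -> torch.Tensor:
    # Returned weight has shape (1, hidden_size); sigmoid(x @ w.T) equals the
    # softmax probability of the positive token under the original head.
    return (lm_head_weight[pos_token_id] - lm_head_weight[neg_token_id]).unsqueeze(0)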