[Model] Extend collect_children and no_init_weights contexts (#32757)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2026-01-22 16:20:27 +08:00
committed by GitHub
parent 1bf1a34b19
commit 2b8a38b6d6
20 changed files with 444 additions and 257 deletions

View File

@@ -41,10 +41,12 @@ from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP
from .interfaces_base import default_pooling_type
from .utils import (
StageMissingLayer,
is_pp_missing_parameter,
make_empty_intermediate_tensors_factory,
make_layers,
maybe_prefix,
no_init_weights,
)
@@ -413,10 +415,16 @@ class InternLM2ForRewardModel(InternLM2ForCausalLM):
prefix: str = "",
model_type: type[InternLM2Model] = InternLM2Model,
):
super().__init__(vllm_config=vllm_config, prefix=prefix, model_type=model_type)
for attr in ("output", "logits_processor"):
delattr(self, attr)
with no_init_weights(
self,
lambda mod: StageMissingLayer("output", mod),
targets=(LogitsProcessor, ParallelLMHead),
):
super().__init__(
vllm_config=vllm_config,
prefix=prefix,
model_type=model_type,
)
config = vllm_config.model_config.hf_config
self.head_dtype = vllm_config.model_config.head_dtype