[Model] Extend collect_children and no_init_weights contexts (#32757)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -41,10 +41,12 @@ from vllm.sequence import IntermediateTensors
|
||||
from .interfaces import SupportsLoRA, SupportsPP
|
||||
from .interfaces_base import default_pooling_type
|
||||
from .utils import (
|
||||
StageMissingLayer,
|
||||
is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory,
|
||||
make_layers,
|
||||
maybe_prefix,
|
||||
no_init_weights,
|
||||
)
|
||||
|
||||
|
||||
@@ -413,10 +415,16 @@ class InternLM2ForRewardModel(InternLM2ForCausalLM):
|
||||
prefix: str = "",
|
||||
model_type: type[InternLM2Model] = InternLM2Model,
|
||||
):
|
||||
super().__init__(vllm_config=vllm_config, prefix=prefix, model_type=model_type)
|
||||
|
||||
for attr in ("output", "logits_processor"):
|
||||
delattr(self, attr)
|
||||
with no_init_weights(
|
||||
self,
|
||||
lambda mod: StageMissingLayer("output", mod),
|
||||
targets=(LogitsProcessor, ParallelLMHead),
|
||||
):
|
||||
super().__init__(
|
||||
vllm_config=vllm_config,
|
||||
prefix=prefix,
|
||||
model_type=model_type,
|
||||
)
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
self.head_dtype = vllm_config.model_config.head_dtype
|
||||
|
||||
Reference in New Issue
Block a user