[Model] use AutoWeightsLoader for baichuan, gpt-neox, mpt (#15939)

Signed-off-by: Jonghyun Choe <andy.choe729@gmail.com>
This commit is contained in:
Jonghyun Choe
2025-04-05 01:38:52 +09:00
committed by GitHub
parent a35a8a8392
commit bf7e3c51ae
3 changed files with 118 additions and 99 deletions

View File

@@ -27,7 +27,7 @@ from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.mpt import MPTConfig
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
@@ -266,6 +266,23 @@ class MPTModel(nn.Module):
hidden_states = self.norm_f(hidden_states)
return hidden_states
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
params_dict = dict(self.named_parameters(remove_duplicate=False))
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
class MPTForCausalLM(nn.Module, SupportsPP):
@@ -318,17 +335,5 @@ class MPTForCausalLM(nn.Module, SupportsPP):
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
params_dict = dict(self.named_parameters(remove_duplicate=False))
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)