[Model] use AutoWeightsLoader for BigCode, GPT-J (#16823)
Signed-off-by: Jonghyun Choe <andy.choe729@gmail.com>
@@ -43,7 +43,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors

 from .interfaces import SupportsLoRA, SupportsPP
-from .utils import (is_pp_missing_parameter,
+from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers)

@@ -244,6 +244,30 @@ class GPTBigCodeModel(nn.Module):
         hidden_states = self.ln_f(hidden_states)
         return hidden_states
+
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
+        params_dict = dict(self.named_parameters(remove_duplicate=False))
+        loaded_params: Set[str] = set()
+        for name, loaded_weight in weights:
+            if ".attn.bias" in name:
+                # Skip attention mask.
+                # NOTE: "c_attn.bias" should not be skipped.
+                continue
+            if is_pp_missing_parameter(name, self):
+                continue
+            param = params_dict[name]
+            weight_loader = getattr(param, "weight_loader",
+                                    default_weight_loader)
+            # TODO (@robertgshaw2-neuralmagic): move to fp8 linear method
+            if "c_attn.input_scale" in name or "c_attn.weight_scale" in name:
+                weight_loader(param, loaded_weight, 'q')
+                weight_loader(param, loaded_weight, 'k')
+                weight_loader(param, loaded_weight, 'v')
+            else:
+                weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params


 class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
     packed_modules_mapping = {"c_attn": ["c_attn"]}
@@ -315,26 +339,8 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):

     def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
-        params_dict = dict(self.named_parameters(remove_duplicate=False))
-        loaded_params: Set[str] = set()
-        for name, loaded_weight in weights:
-            if "lm_head.weight" in name:
-                continue
-            if ".attn.bias" in name:
-                # Skip attention mask.
-                # NOTE: "c_attn.bias" should not be skipped.
-                continue
-            if is_pp_missing_parameter(name, self):
-                continue
-            param = params_dict[name]
-            weight_loader = getattr(param, "weight_loader",
-                                    default_weight_loader)
-            # TODO (@robertgshaw2-neuralmagic): move to fp8 linear method
-            if "c_attn.input_scale" in name or "c_attn.weight_scale" in name:
-                weight_loader(param, loaded_weight, 'q')
-                weight_loader(param, loaded_weight, 'k')
-                weight_loader(param, loaded_weight, 'v')
-            else:
-                weight_loader(param, loaded_weight)
-            loaded_params.add(name)
-        return loaded_params
+        loader = AutoWeightsLoader(
+            self,
+            skip_prefixes=(["lm_head."]),
+        )
+        return loader.load_weights(weights)
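What the last hunk buys: the per-parameter loop now lives on GPTBigCodeModel, and GPTBigCodeForCausalLM only wraps itself in an AutoWeightsLoader that drops anything under "lm_head." (the removed loop skipped "lm_head.weight" the same way) and routes the remaining checkpoint tensors to the submodule that owns them. Below is a minimal, self-contained sketch of that delegation pattern; ToyAutoWeightsLoader, ToyInnerModel, and ToyForCausalLM are invented names for illustration only, not vLLM's actual AutoWeightsLoader from .utils, which is more general.

from typing import Iterable, List, Optional, Set, Tuple

import torch
import torch.nn as nn


class ToyAutoWeightsLoader:
    # Hypothetical stand-in, NOT vLLM's AutoWeightsLoader: it only mimics the
    # behaviour this diff relies on -- drop checkpoint tensors under
    # skip_prefixes, then route the rest to the matching child module,
    # preferring that child's own load_weights() when it defines one.

    def __init__(self,
                 module: nn.Module,
                 skip_prefixes: Optional[List[str]] = None):
        self.module = module
        self.skip_prefixes = skip_prefixes or []

    def load_weights(
            self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]:
        children = dict(self.module.named_children())
        loaded: Set[str] = set()
        for name, loaded_weight in weights:
            if any(name.startswith(p) for p in self.skip_prefixes):
                continue  # e.g. "lm_head." tensors, dropped via skip_prefixes
            child_name, _, rest = name.partition(".")
            child = children[child_name]
            if hasattr(child, "load_weights"):
                # Delegate: the role GPTBigCodeModel.load_weights plays above.
                for sub_name in child.load_weights([(rest, loaded_weight)]):
                    loaded.add(f"{child_name}.{sub_name}")
            else:
                dict(child.named_parameters())[rest].data.copy_(loaded_weight)
                loaded.add(name)
        return loaded


class ToyInnerModel(nn.Module):
    # Plays the part of GPTBigCodeModel: it owns the per-parameter loop.

    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)

    def load_weights(
            self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]:
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            params_dict[name].data.copy_(loaded_weight)
            loaded_params.add(name)
        return loaded_params


class ToyForCausalLM(nn.Module):
    # Plays the part of GPTBigCodeForCausalLM after the change: no loop of
    # its own, just a loader that skips "lm_head." and delegates.

    def __init__(self):
        super().__init__()
        self.transformer = ToyInnerModel()
        self.lm_head = nn.Linear(4, 4, bias=False)

    def load_weights(
            self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]:
        loader = ToyAutoWeightsLoader(self, skip_prefixes=["lm_head."])
        return loader.load_weights(weights)


checkpoint = [
    ("transformer.proj.weight", torch.ones(4, 4)),
    ("transformer.proj.bias", torch.zeros(4)),
    ("lm_head.weight", torch.ones(4, 4)),  # never reaches the inner loader
]
print(sorted(ToyForCausalLM().load_weights(checkpoint)))
# ['transformer.proj.bias', 'transformer.proj.weight']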