diff --git a/vllm/model_executor/models/jais.py b/vllm/model_executor/models/jais.py index b8c1310cb..cc0c1aa01 100644 --- a/vllm/model_executor/models/jais.py +++ b/vllm/model_executor/models/jais.py @@ -53,6 +53,7 @@ from vllm.transformers_utils.configs.jais import JAISConfig from .interfaces import SupportsPP from .utils import ( + AutoWeightsLoader, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, @@ -311,6 +312,35 @@ class JAISModel(nn.Module): hidden_states = self.ln_f(hidden_states) return hidden_states + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + if ".attn.bias" in name or ".attn.masked_bias" in name: + # Skip attention mask. + # NOTE: "c_attn.bias" should not be skipped. + continue + if "relative_pe" in name: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + # The HF's GPT-2 implementation uses Conv1D instead of Linear. + # Because of this, we need to transpose the weights. + # Note(zhuohan): the logic below might break quantized models. + for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]: + if conv1d_weight_name not in name: + continue + if not name.endswith(".weight"): + continue + loaded_weight = loaded_weight.t() + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + class JAISLMHeadModel(nn.Module, SupportsPP): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): @@ -364,36 +394,8 @@ class JAISLMHeadModel(nn.Module, SupportsPP): return logits def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - params_dict = dict(self.named_parameters(remove_duplicate=False)) - loaded_params: set[str] = set() - for name, loaded_weight in weights: - if "lm_head.weight" in name: - # GPT-2 ties the weights of the embedding layer and the final - # linear layer. - continue - if ".attn.bias" in name or ".attn.masked_bias" in name: - # Skip attention mask. - # NOTE: "c_attn.bias" should not be skipped. - continue - if "relative_pe" in name: - continue - if not name.startswith("transformer."): - name = "transformer." + name - - if is_pp_missing_parameter(name, self): - continue - - param = params_dict[name] - # The HF's GPT-2 implementation uses Conv1D instead of Linear. - # Because of this, we need to transpose the weights. - # Note(zhuohan): the logic below might break quantized models. - for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]: - if conv1d_weight_name not in name: - continue - if not name.endswith(".weight"): - continue - loaded_weight = loaded_weight.t() - weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, loaded_weight) - loaded_params.add(name) - return loaded_params + loader = AutoWeightsLoader( + self, + skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None), + ) + return loader.load_weights(weights)