[Model] Add AutoWeightsLoader support for jais (#38074)

Signed-off-by: grYe99 <guorongye99@gmail.com>
Co-authored-by: grYe99 <guorongye99@gmail.com>
Authored by grYe99 on 2026-03-25 20:38:40 +08:00, committed by GitHub
parent d6bb2a9d9a
commit 7ac48fd357


@@ -53,6 +53,7 @@ from vllm.transformers_utils.configs.jais import JAISConfig
 from .interfaces import SupportsPP
 from .utils import (
+    AutoWeightsLoader,
     is_pp_missing_parameter,
     make_empty_intermediate_tensors_factory,
     make_layers,
@@ -311,6 +312,35 @@ class JAISModel(nn.Module):
         hidden_states = self.ln_f(hidden_states)
         return hidden_states
 
+    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
+        params_dict = dict(self.named_parameters(remove_duplicate=False))
+        loaded_params: set[str] = set()
+        for name, loaded_weight in weights:
+            if ".attn.bias" in name or ".attn.masked_bias" in name:
+                # Skip attention mask.
+                # NOTE: "c_attn.bias" should not be skipped.
+                continue
+            if "relative_pe" in name:
+                continue
+            if is_pp_missing_parameter(name, self):
+                continue
+            param = params_dict[name]
+            # The HF's GPT-2 implementation uses Conv1D instead of Linear.
+            # Because of this, we need to transpose the weights.
+            # Note(zhuohan): the logic below might break quantized models.
+            for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]:
+                if conv1d_weight_name not in name:
+                    continue
+                if not name.endswith(".weight"):
+                    continue
+                loaded_weight = loaded_weight.t()
+            weight_loader = getattr(param, "weight_loader", default_weight_loader)
+            weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
+
 
 class JAISLMHeadModel(nn.Module, SupportsPP):
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
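The transpose in the Conv1D branch above is needed because HF's GPT-2-style Conv1D stores its weight as [in_features, out_features], while torch.nn.Linear stores [out_features, in_features]. A minimal standalone sketch of that equivalence (illustration only, not part of this commit):

import torch

in_features, out_features = 4, 8
x = torch.randn(2, in_features)
conv1d_weight = torch.randn(in_features, out_features)  # HF Conv1D layout: [in, out]

# Conv1D computes x @ W directly.
y_conv = x @ conv1d_weight

# nn.Linear keeps its weight as [out, in], so the checkpoint tensor must be
# transposed on load, which is the same .t() applied in load_weights above.
linear = torch.nn.Linear(in_features, out_features, bias=False)
with torch.no_grad():
    linear.weight.copy_(conv1d_weight.t())
y_linear = linear(x)

assert torch.allclose(y_conv, y_linear, atol=1e-6)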
@@ -364,36 +394,8 @@ class JAISLMHeadModel(nn.Module, SupportsPP):
         return logits
 
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
-        params_dict = dict(self.named_parameters(remove_duplicate=False))
-        loaded_params: set[str] = set()
-        for name, loaded_weight in weights:
-            if "lm_head.weight" in name:
-                # GPT-2 ties the weights of the embedding layer and the final
-                # linear layer.
-                continue
-            if ".attn.bias" in name or ".attn.masked_bias" in name:
-                # Skip attention mask.
-                # NOTE: "c_attn.bias" should not be skipped.
-                continue
-            if "relative_pe" in name:
-                continue
-            if not name.startswith("transformer."):
-                name = "transformer." + name
-            if is_pp_missing_parameter(name, self):
-                continue
-            param = params_dict[name]
-            # The HF's GPT-2 implementation uses Conv1D instead of Linear.
-            # Because of this, we need to transpose the weights.
-            # Note(zhuohan): the logic below might break quantized models.
-            for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]:
-                if conv1d_weight_name not in name:
-                    continue
-                if not name.endswith(".weight"):
-                    continue
-                loaded_weight = loaded_weight.t()
-            weight_loader = getattr(param, "weight_loader", default_weight_loader)
-            weight_loader(param, loaded_weight)
-            loaded_params.add(name)
-        return loaded_params
+        loader = AutoWeightsLoader(
+            self,
+            skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None),
+        )
+        return loader.load_weights(weights)
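For context, the AutoWeightsLoader call above delegates loading to submodules: it skips any checkpoint name matching skip_prefixes (here "lm_head." when tie_word_embeddings is set, since the head shares the embedding weights), strips each child's prefix such as "transformer.", and hands the remainder to that child's own load_weights (JAISModel.load_weights in the hunk above). The toy stand-in below is a simplified sketch of that dispatch, not vLLM's actual implementation; toy_auto_load and its signature are hypothetical:

from collections.abc import Iterable

import torch
import torch.nn as nn

def toy_auto_load(
    root: nn.Module,
    weights: Iterable[tuple[str, torch.Tensor]],
    skip_prefixes: list[str] | None = None,
) -> set[str]:
    # Hypothetical, simplified stand-in for AutoWeightsLoader.load_weights.
    loaded: set[str] = set()
    children = dict(root.named_children())
    for name, tensor in weights:
        if skip_prefixes and any(name.startswith(p) for p in skip_prefixes):
            continue  # e.g. "lm_head." when embeddings are tied
        child_name, _, rest = name.partition(".")
        child = children[child_name]
        if hasattr(child, "load_weights"):
            # Delegate with the owning child's prefix stripped; the real
            # loader batches these calls instead of going tensor by tensor.
            for sub in child.load_weights([(rest, tensor)]):
                loaded.add(f"{child_name}.{sub}")
        else:
            child.get_parameter(rest).data.copy_(tensor)
            loaded.add(name)
    return loaded

Under this sketch, toy_auto_load(model, weights, skip_prefixes=["lm_head."]) reproduces the skip behaviour configured above for tied embeddings.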