[Model] use AutoWeightsLoader for stablelm,starcoder2,zamba2 (#16103)

Signed-off-by: rongfu.leng <rongfu.leng@daocloud.io>
This commit is contained in:
rongfu.leng
2025-04-06 20:52:01 +08:00
committed by GitHub
parent c2a9671510
commit 242a637aea
3 changed files with 135 additions and 121 deletions

View File

@@ -39,7 +39,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from .interfaces import HasInnerState, IsHybrid, SupportsV0Only
from .utils import maybe_prefix
from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
class Zamba2LoRA(nn.Module):
@@ -777,6 +777,37 @@ class Zamba2Model(nn.Module):
hidden_states = self.final_layernorm(hidden_states)
return hidden_states
# Load checkpoint tensors into this module's parameters.
#
# `weights` yields (checkpoint_name, tensor) pairs. The return value is the
# set of parameter names recorded as loaded (names are recorded after the
# q/k/v renaming performed below).
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
# Each entry redirects a checkpoint name containing `shard_name` to the
# parameter named with `param_name`, and passes `shard_id` to that
# parameter's loader.
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
]
# Name -> parameter lookup table for everything registered on this module.
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for chkpt_weight_name, loaded_weight in weights:
for param_name, weight_name, shard_id in stacked_params_mapping:
# Substring test: this mapping entry does not apply to the name.
if weight_name not in chkpt_weight_name:
continue
# Rewrite the shard name to the fused parameter name. The rebinding
# also changes the name that is later added to `loaded_params`.
chkpt_weight_name = chkpt_weight_name.replace(
weight_name, param_name)
# Direct indexing: a rewritten name that is not a registered
# parameter raises KeyError here rather than being skipped.
param = params_dict[chkpt_weight_name]
# No fallback loader on this path; the parameter must expose
# `weight_loader`, which receives the shard id as a third argument.
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
# `else` of the inner `for`: runs only when no mapping entry matched
# (the loop finished without `break`).
else:
# Checkpoint entries with no matching parameter are ignored silently.
if chkpt_weight_name not in params_dict:
continue
param = params_dict[chkpt_weight_name]
# Use the parameter's own loader when it has one, otherwise the
# generic `default_weight_loader`.
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
# NOTE(review): indentation is not visible in this view; presumably this
# runs for both the stacked and the plain path -- confirm.
loaded_params.add(chkpt_weight_name)
return loaded_params
class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsV0Only):
"""Zamba2 model with causal language modeling head.
@@ -787,6 +818,12 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsV0Only):
- Support for model parallelism and quantization
- Sampling capabilities for text generation
"""
# To ensure correct weight loading and mapping.
# Substring renames handed to the weight loader as `mapper=` in
# `load_weights`: "A_log" -> "A", "0.weight" -> "A.weight",
# "1.weight" -> "B.weight".
# NOTE(review): the hand-written loop in `load_weights` applied the
# "0.weight"/"1.weight" renames only to keys containing "adapter_list" (and,
# because of the `elif`, never to keys containing "A_log"). This mapping has
# no such guard -- confirm that no other checkpoint key contains "0.weight"
# or "1.weight" (for example, any name whose last index ends in 0 or 1).
hf_to_vllm_mapper = WeightsMapper(orig_to_new_substr={
"A_log": "A",
"0.weight": "A.weight",
"1.weight": "B.weight",
})
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
"""Initialize the Zamba2 model for causal language modeling.
@@ -992,40 +1029,5 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsV0Only):
# Load checkpoint tensors into the model; annotated to return the set of
# loaded parameter names.
#
# NOTE(review): this span reads like a diff hunk rendered without +/- markers:
# the hand-written loader down to `return loaded_params` appears to be the
# removed implementation, and the final two lines (AutoWeightsLoader) the
# added one. Read literally, everything after `return loaded_params` would be
# unreachable -- confirm against the commit.
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
# Each entry redirects a checkpoint name containing `shard_name` to the
# parameter named with `param_name`, passing `shard_id` to its loader.
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
]
# First pass: rename checkpoint keys and collect every tensor in a dict,
# which consumes the whole `weights` iterable before anything is loaded.
weights_dict = {}
for key, loaded_weight in weights:
if "A_log" in key:
key = key.replace("A_log", "A")
# `elif`: these two renames apply only to keys that contain
# "adapter_list" and do not contain "A_log".
elif "adapter_list" in key:
key = key.replace("0.weight", "A.weight")
key = key.replace("1.weight", "B.weight")
weights_dict[key] = loaded_weight
# Second pass: copy the renamed tensors into the matching parameters.
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for chkpt_weight_name, loaded_weight in weights_dict.items():
for param_name, weight_name, shard_id in stacked_params_mapping:
# Substring test: this mapping entry does not apply to the name.
if weight_name not in chkpt_weight_name:
continue
chkpt_weight_name = chkpt_weight_name.replace(
weight_name, param_name)
# Direct indexing: raises KeyError if the rewritten name is not a
# registered parameter.
param = params_dict[chkpt_weight_name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
# `else` of the inner `for`: runs only when no mapping entry matched.
else:
# Checkpoint entries with no matching parameter are ignored silently.
if chkpt_weight_name not in params_dict:
continue
param = params_dict[chkpt_weight_name]
# Fall back to `default_weight_loader` when the parameter has no loader
# of its own.
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
# NOTE(review): indentation is not visible in this view; presumably this
# runs for both the stacked and the plain path -- confirm.
loaded_params.add(chkpt_weight_name)
return loaded_params
# Replacement implementation: delegate to AutoWeightsLoader built around this
# module, passing the class-level `hf_to_vllm_mapper` as `mapper=`.
# NOTE(review): the q/k/v handling is presumably now done by the inner
# model's own `load_weights` -- confirm how AutoWeightsLoader dispatches to
# submodules.
loader = AutoWeightsLoader(self)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)