[Model] use AutoWeightsLoader for stablelm,starcoder2,zamba2 (#16103)
Signed-off-by: rongfu.leng <rongfu.leng@daocloud.io>
This commit is contained in:
@@ -39,7 +39,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .interfaces import HasInnerState, IsHybrid, SupportsV0Only
|
||||
from .utils import maybe_prefix
|
||||
from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
|
||||
|
||||
|
||||
class Zamba2LoRA(nn.Module):
|
||||
@@ -777,6 +777,37 @@ class Zamba2Model(nn.Module):
|
||||
hidden_states = self.final_layernorm(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
                                               torch.Tensor]]) -> Set[str]:
    """Copy checkpoint tensors into this module's parameters.

    Checkpoint names containing ``q_proj``, ``k_proj`` or ``v_proj`` are
    renamed to the fused ``qkv_proj`` parameter and handed to that
    parameter's ``weight_loader`` together with the shard id. Any other
    name is loaded only if it matches an entry of
    ``self.named_parameters()``; names that match nothing are skipped.

    Args:
        weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

    Returns:
        The parameter names that received a tensor (fused names for the
        q/k/v shards).

    Raises:
        KeyError: If a q/k/v name, once renamed, is not a parameter of
            this module.
    """
    # (fused parameter name, checkpoint shard name, shard id)
    fused_shards = (
        ("qkv_proj", "q_proj", "q"),
        ("qkv_proj", "k_proj", "k"),
        ("qkv_proj", "v_proj", "v"),
    )
    known_params = dict(self.named_parameters())
    loaded: Set[str] = set()

    for name, tensor in weights:
        # First shard rule whose checkpoint name occurs in `name`, if any.
        rule = next((entry for entry in fused_shards if entry[1] in name),
                    None)
        if rule is not None:
            fused_name, shard_name, shard_id = rule
            name = name.replace(shard_name, fused_name)
            target = known_params[name]
            # Fused parameters must provide their own shard-aware loader.
            target.weight_loader(target, tensor, shard_id)
        elif name in known_params:
            target = known_params[name]
            load_fn = getattr(target, "weight_loader", default_weight_loader)
            load_fn(target, tensor)
        else:
            # Not a parameter of this module: ignore the tensor.
            continue
        loaded.add(name)

    return loaded
|
||||
|
||||
|
||||
class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsV0Only):
|
||||
"""Zamba2 model with causal language modeling head.
|
||||
@@ -787,6 +818,12 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsV0Only):
|
||||
- Support for model parallelism and quantization
|
||||
- Sampling capabilities for text generation
|
||||
"""
|
||||
# To ensure correct weight loading and mapping.
# Substring rewrite rules (original -> new) for checkpoint weight names;
# load_weights passes this object as the `mapper` argument of
# AutoWeightsLoader.load_weights.
# NOTE(review): the rules carry no guard on which keys they apply to --
# presumably only the adapter weights contain "0.weight"/"1.weight";
# confirm against the checkpoint key names.
|
||||
hf_to_vllm_mapper = WeightsMapper(orig_to_new_substr={
|
||||
"A_log": "A",
|
||||
"0.weight": "A.weight",
|
||||
|
||||
"1.weight": "B.weight",
|
||||
})
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
|
||||
"""Initialize the Zamba2 model for causal language modeling.
|
||||
@@ -992,40 +1029,5 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsV0Only):
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
                                               torch.Tensor]]) -> Set[str]:
    """Load checkpoint weights by delegating to ``AutoWeightsLoader``.

    The hand-written loading loop that used to precede the delegation
    ended in ``return loaded_params``, which left the ``AutoWeightsLoader``
    call unreachable and ``hf_to_vllm_mapper`` unused. Only the delegation
    is kept; checkpoint-name rewriting is done by ``hf_to_vllm_mapper``.

    Args:
        weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

    Returns:
        The result of ``AutoWeightsLoader.load_weights`` (annotated as the
        set of loaded parameter names).
    """
    loader = AutoWeightsLoader(self)
    return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
|
||||
|
||||
Reference in New Issue
Block a user