[Misc] Add BNB support to GLM4-V model (#12184)
Signed-off-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
@@ -41,7 +41,7 @@ from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
|
||||
from vllm.transformers_utils.configs import ChatGLMConfig
|
||||
|
||||
from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
|
||||
from .utils import (is_pp_missing_parameter,
|
||||
from .utils import (AutoWeightsLoader, WeightsMapper, is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory, make_layers,
|
||||
maybe_prefix)
|
||||
|
||||
@@ -605,9 +605,50 @@ class ChatGLMModel(nn.Module):
|
||||
return IntermediateTensors({"hidden_states": hidden_states})
|
||||
return hidden_states
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("linear_proj.merged_proj", "linear_proj.gate_proj", 0),
|
||||
("linear_proj.merged_proj", "linear_proj.dense_h_to_4h", 1),
|
||||
]
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: Set[str] = set()
|
||||
|
||||
for name, loaded_weight in weights:
|
||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
if "rotary_pos_emb.inv_freq" in name:
|
||||
continue
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
|
||||
class ChatGLMBaseModel(nn.Module, SupportsLoRA, SupportsPP):
|
||||
|
||||
hf_to_vllm_mapper = WeightsMapper(
|
||||
orig_to_new_substr={".word_embeddings": ""}, )
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
@@ -660,52 +701,9 @@ class ChatGLMBaseModel(nn.Module, SupportsLoRA, SupportsPP):
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
# Merge two ColumnParallelLinear into one MergedColumnParallelLinear
|
||||
merged_weights_dict: Dict[str, Dict[str, Optional[torch.Tensor]]] = {
|
||||
"transformer.vision.linear_proj.merged_proj.weight": {
|
||||
"transformer.vision.linear_proj.gate_proj.weight": None,
|
||||
"transformer.vision.linear_proj.dense_h_to_4h.weight": None,
|
||||
}
|
||||
}
|
||||
|
||||
params_dict = dict(self.named_parameters(remove_duplicate=False))
|
||||
loaded_params: Set[str] = set()
|
||||
for name, loaded_weight in weights:
|
||||
is_weight_to_be_merge = False
|
||||
for _, merged_weight_dict in merged_weights_dict.items():
|
||||
if name in merged_weight_dict:
|
||||
assert merged_weight_dict[name] is None
|
||||
merged_weight_dict[name] = loaded_weight
|
||||
is_weight_to_be_merge = True
|
||||
if is_weight_to_be_merge:
|
||||
continue
|
||||
if "rotary_pos_emb.inv_freq" in name:
|
||||
continue
|
||||
if "word_embeddings" in name:
|
||||
name = name.replace(".word_embeddings", "")
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
|
||||
for combined_name, merged_weight_dict in merged_weights_dict.items():
|
||||
if combined_name in params_dict:
|
||||
param = params_dict[combined_name]
|
||||
combined_weight = torch.cat(list(merged_weight_dict.values()),
|
||||
dim=0)
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, combined_weight)
|
||||
loaded_params.add(combined_name)
|
||||
return loaded_params
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
loader = AutoWeightsLoader(self)
|
||||
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
|
||||
|
||||
|
||||
class ChatGLM(ChatGLMBaseModel):
|
||||
@@ -726,6 +724,7 @@ class ChatGLM(ChatGLMBaseModel):
|
||||
|
||||
|
||||
class ChatGLMV(ChatGLMBaseModel, SupportsMultiModal):
|
||||
|
||||
packed_modules_mapping = {
|
||||
"query_key_value": ["query_key_value"],
|
||||
"dense_h_to_4h": ["dense_h_to_4h"],
|
||||
@@ -777,7 +776,7 @@ class ChatGLMForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP,
|
||||
) -> None:
|
||||
config = vllm_config.model_config.hf_config
|
||||
# Initialize VL
|
||||
if hasattr(config, "visual"):
|
||||
if hasattr(config, "vision_config"):
|
||||
return ChatGLMV(vllm_config=vllm_config, prefix=prefix)
|
||||
# Initialize LLM
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user