[Model][Quant] Fix GLM, Fix fused module mappings for quantization (#12634)
Signed-off-by: mgoin <michael@neuralmagic.com>
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
Co-authored-by: mgoin <michael@neuralmagic.com>
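This change threads a `prefix` argument through every quantizable GLM submodule (attention, MLP, and embedding) and rebuilds the `ChatGLMForCausalLM` dispatcher so fused-module mappings are in place before initialization; the hunks below appear to come from vllm/model_executor/models/chatglm.py. Quantization configs resolve per-module schemes by a layer's fully qualified name, so a layer constructed without its prefix cannot match any per-module override. A minimal sketch of that failure mode, using a hypothetical `ToyQuantConfig` rather than vLLM's actual API:

    from typing import Optional

    class ToyQuantConfig:
        """Hypothetical stand-in for a quantization config that maps
        fully qualified module names to quantization schemes."""

        def __init__(self, overrides: dict[str, str]):
            self.overrides = overrides

        def scheme_for(self, prefix: str) -> Optional[str]:
            # A layer built without its prefix effectively asks for ""
            # and silently misses its per-module override.
            return self.overrides.get(prefix)

    cfg = ToyQuantConfig({"encoder.layers.0.mlp.dense_h_to_4h": "w8a8"})
    print(cfg.scheme_for("encoder.layers.0.mlp.dense_h_to_4h"))  # w8a8
    print(cfg.scheme_for(""))  # None: the failure mode this commit fixes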
@@ -265,12 +265,14 @@ class GLMAttention(nn.Module):
             self.total_num_kv_heads,
             bias=config.add_bias_linear or config.add_qkv_bias,
             quant_config=quant_config,
+            prefix=f"{prefix}.query_key_value",
         )
         self.dense = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             config.hidden_size,
             bias=config.add_bias_linear,
             quant_config=quant_config,
+            prefix=f"{prefix}.dense",
         )
 
         # https://huggingface.co/THUDM/chatglm3-6b-32k/blob/e210410255278dd9d74463cf396ba559c0ef801c/modeling_chatglm.py#L141
@@ -327,6 +329,7 @@ class GLMMLP(nn.Module):
         self,
         config: ChatGLMConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
 
@@ -338,6 +341,7 @@ class GLMMLP(nn.Module):
             [config.ffn_hidden_size] * 2,
             bias=config.add_bias_linear,
             quant_config=quant_config,
+            prefix=f"{prefix}.dense_h_to_4h",
         )
 
         self.activation_func = SiluAndMul()
@@ -348,6 +352,7 @@ class GLMMLP(nn.Module):
             config.hidden_size,
             bias=config.add_bias_linear,
             quant_config=quant_config,
+            prefix=f"{prefix}.dense_4h_to_h",
         )
 
     def forward(self, hidden_states):
@@ -396,7 +401,7 @@ class GLMBlock(nn.Module):
             config.hidden_size, eps=config.layernorm_epsilon)
 
         # MLP
-        self.mlp = GLMMLP(config, quant_config)
+        self.mlp = GLMMLP(config, quant_config, prefix=f"{prefix}.mlp")
 
     def forward(
         self,
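With the prefixes above wired through `GLMAttention`, `GLMMLP`, and `GLMBlock`, every leaf layer now receives its full dotted path from the model root. A toy composition of the resulting names (the root string and layer index are assumptions about how the stack is materialized):

    root = "transformer.encoder.layers.0"
    for leaf in ("self_attention.query_key_value",
                 "self_attention.dense",
                 "mlp.dense_h_to_4h",
                 "mlp.dense_4h_to_h"):
        print(f"{root}.{leaf}")  # e.g. transformer.encoder.layers.0.mlp.dense_h_to_4h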
@@ -507,7 +512,8 @@ class ChatGLMModel(nn.Module):
 
         self.embedding = VocabParallelEmbedding(config.padded_vocab_size,
                                                 config.hidden_size,
-                                                quant_config=quant_config)
+                                                quant_config=quant_config,
+                                                prefix=f"{prefix}.embedding")
 
         self.num_layers = config.num_layers
         self.multi_query_group_num = config.multi_query_group_num
@@ -766,6 +772,7 @@ class ChatGLMForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP,
                          SupportsMultiModal):
     # Ensure that the LoRA support check passes when the class is not
     # initialized, but set all these attributes to empty.
+    # These will be updated when an instance class is selected
     packed_modules_mapping = {}
     supported_lora_modules = []
     embedding_modules = {}
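For context, `packed_modules_mapping` declares which per-projection checkpoint tensors are fused into one runtime module; quantization and LoRA use it to remap per-tensor metadata onto the fused layer. A sketch of the shape of such a mapping, with Llama-style names used purely as illustration (GLM checkpoints already ship fused QKV, so its entries are closer to identity):

    # Illustrative only: fused runtime module -> constituent checkpoint
    # projections, listed in concatenation order.
    packed_modules_mapping: dict[str, list[str]] = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }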
@@ -777,9 +784,18 @@ class ChatGLMForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP,
         prefix: str = "",
     ) -> None:
         config = vllm_config.model_config.hf_config
+
         # Initialize VL
-        if hasattr(config, "vision_config"):
-            return ChatGLMV(vllm_config=vllm_config, prefix=prefix)
+        if hasattr(config, "vision_config"):  # noqa: SIM108
+            instance_cls = ChatGLMV
         # Initialize LLM
         else:
-            return ChatGLM(vllm_config=vllm_config, prefix=prefix)
+            instance_cls = ChatGLM
+
+        # quant_config references base class members,
+        # so update values before init is called
+        cls.packed_modules_mapping.update(instance_cls.packed_modules_mapping)
+        cls.supported_lora_modules += instance_cls.supported_lora_modules
+        cls.embedding_modules.update(instance_cls.embedding_modules)
+        cls.embedding_padding_modules += instance_cls.embedding_padding_modules
+        return instance_cls(vllm_config=vllm_config, prefix=prefix)
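The rewritten `__new__` dispatches to the vision or text variant and copies the chosen class's fused-module and LoRA attributes onto the dispatcher class before the instance is constructed, so quantization setup that was handed the dispatcher sees the right mappings. A self-contained sketch of the pattern with simplified stand-in names (not vLLM's real classes):

    class LLMBase:
        packed_modules_mapping: dict[str, list[str]] = {}

        def __init__(self, name: str):
            self.name = name

    class TextModel(LLMBase):
        packed_modules_mapping = {"query_key_value": ["query_key_value"]}

    class VisionModel(LLMBase):
        packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}

    class AutoModel(LLMBase):
        # Empty default so class-level capability checks pass pre-dispatch.
        packed_modules_mapping = {}

        def __new__(cls, name: str, vision: bool = False):
            instance_cls = VisionModel if vision else TextModel
            # Mirror the commit: update the dispatcher's class attribute
            # before the instance is built, because downstream setup reads
            # it from the class it was handed, not from the instance.
            cls.packed_modules_mapping.update(instance_cls.packed_modules_mapping)
            return instance_cls(name)

    model = AutoModel("glm", vision=False)
    assert isinstance(model, TextModel)
    assert "query_key_value" in AutoModel.packed_modules_mapping

Because `TextModel` and `VisionModel` are not subclasses of `AutoModel`, returning the instance from `__new__` skips `AutoModel.__init__`, which is what makes this dispatch safe.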