[Model][Quant] Fix GLM, Fix fused module mappings for quantization (#12634)

Signed-off-by: mgoin <michael@neuralmagic.com>
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
Co-authored-by: mgoin <michael@neuralmagic.com>
Kyle Sayers
2025-02-05 00:32:06 -05:00
committed by GitHub
parent 686006a220
commit 7ff7a638b6
12 changed files with 194 additions and 150 deletions

vllm/model_executor/models/chatglm.py

@@ -265,12 +265,14 @@ class GLMAttention(nn.Module):
             self.total_num_kv_heads,
             bias=config.add_bias_linear or config.add_qkv_bias,
             quant_config=quant_config,
+            prefix=f"{prefix}.query_key_value",
         )
         self.dense = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             config.hidden_size,
             bias=config.add_bias_linear,
             quant_config=quant_config,
+            prefix=f"{prefix}.dense",
         )
         # https://huggingface.co/THUDM/chatglm3-6b-32k/blob/e210410255278dd9d74463cf396ba559c0ef801c/modeling_chatglm.py#L141

@@ -327,6 +329,7 @@ class GLMMLP(nn.Module):
         self,
         config: ChatGLMConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()

@@ -338,6 +341,7 @@ class GLMMLP(nn.Module):
             [config.ffn_hidden_size] * 2,
             bias=config.add_bias_linear,
             quant_config=quant_config,
+            prefix=f"{prefix}.dense_h_to_4h",
         )
         self.activation_func = SiluAndMul()

@@ -348,6 +352,7 @@ class GLMMLP(nn.Module):
             config.hidden_size,
             bias=config.add_bias_linear,
             quant_config=quant_config,
+            prefix=f"{prefix}.dense_4h_to_h",
         )

     def forward(self, hidden_states):

@@ -396,7 +401,7 @@ class GLMBlock(nn.Module):
             config.hidden_size, eps=config.layernorm_epsilon)

         # MLP
-        self.mlp = GLMMLP(config, quant_config)
+        self.mlp = GLMMLP(config, quant_config, prefix=f"{prefix}.mlp")

     def forward(
         self,

@@ -507,7 +512,8 @@ class ChatGLMModel(nn.Module):
         self.embedding = VocabParallelEmbedding(config.padded_vocab_size,
                                                 config.hidden_size,
-                                                quant_config=quant_config)
+                                                quant_config=quant_config,
+                                                prefix=f"{prefix}.embedding")

         self.num_layers = config.num_layers
         self.multi_query_group_num = config.multi_query_group_num
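
The `prefix` arguments threaded through the hunks above carry each module's fully qualified name down to the quantization layer: vLLM's QuantizationConfig.get_quant_method(layer, prefix) receives that name and uses it to pick a per-module scheme or to skip modules on an ignore list. Below is a minimal sketch of the lookup pattern, assuming an illustrative class, ignore list, and scheme strings rather than vLLM's actual implementation:

from typing import Optional

class SketchQuantConfig:
    """Illustrative stand-in for a quantization config keyed by prefix."""

    def __init__(self, ignore: list[str]):
        # Fully qualified module names to leave unquantized.
        self.ignore = ignore

    def get_quant_method(self, layer: object, prefix: str) -> Optional[str]:
        # Matching is by full prefix, so a module constructed without
        # its prefix can never match the ignore list and would be
        # quantized (or skipped) incorrectly.
        if prefix in self.ignore:
            return None  # caller falls back to an unquantized method
        return "w8a8"

cfg = SketchQuantConfig(ignore=["transformer.output_layer"])
assert cfg.get_quant_method(None, "transformer.output_layer") is None
assert cfg.get_quant_method(
    None,
    "transformer.encoder.layers.0.self_attention.query_key_value") == "w8a8"
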
@@ -766,6 +772,7 @@ class ChatGLMForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP,
                          SupportsMultiModal):
     # Ensure that the LoRA support check passes when the class is not
     # initialized, but set all these attributes to empty.
+    # These will be updated when an instance class is selected
     packed_modules_mapping = {}
     supported_lora_modules = []
     embedding_modules = {}
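
These empty placeholders are filled in from the concrete subclass picked below. For orientation, the GLM classes declare values along these lines (illustrative, not verbatim; see chatglm.py for the exact entries). ChatGLM checkpoints already store QKV and the gate/up MLP projection fused, so each packed entry maps a fused module onto a single checkpoint name:

# Illustrative values for the concrete GLM classes.
packed_modules_mapping = {
    "query_key_value": ["query_key_value"],
    "dense_h_to_4h": ["dense_h_to_4h"],
}
supported_lora_modules = [
    "query_key_value",
    "dense",
    "dense_h_to_4h",
    "dense_4h_to_h",
]
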
@@ -777,9 +784,18 @@ class ChatGLMForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP,
         prefix: str = "",
     ) -> None:
         config = vllm_config.model_config.hf_config
         # Initialize VL
-        if hasattr(config, "vision_config"):
-            return ChatGLMV(vllm_config=vllm_config, prefix=prefix)
+        if hasattr(config, "vision_config"):  # noqa: SIM108
+            instance_cls = ChatGLMV
         # Initialize LLM
         else:
-            return ChatGLM(vllm_config=vllm_config, prefix=prefix)
+            instance_cls = ChatGLM
+
+        # quant_config references base class members,
+        # so update values before init is called
+        cls.packed_modules_mapping.update(instance_cls.packed_modules_mapping)
+        cls.supported_lora_modules += instance_cls.supported_lora_modules
+        cls.embedding_modules.update(instance_cls.embedding_modules)
+        cls.embedding_padding_modules += instance_cls.embedding_padding_modules
+
+        return instance_cls(vllm_config=vllm_config, prefix=prefix)
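
The ordering matters because quantization setup reads packed_modules_mapping off the class to translate a fused vLLM module into the unfused shard names a checkpoint may use, treating the fused module as quantized only when all of its shards are. A rough sketch of that translation, using hypothetical Llama-style names rather than vLLM's actual detection code:

# Hypothetical sketch of how a fused-module mapping is consumed.
packed_modules_mapping = {
    "qkv_proj": ["q_proj", "k_proj", "v_proj"],
    "gate_up_proj": ["gate_proj", "up_proj"],
}

def shard_names(prefix: str) -> list[str]:
    # "model.layers.0.self_attn.qkv_proj" ->
    # ["model.layers.0.self_attn.q_proj", ...]
    parent, _, leaf = prefix.rpartition(".")
    fused = packed_modules_mapping.get(leaf)
    if fused is None:
        return [prefix]  # not a fused module
    return [f"{parent}.{name}" if parent else name for name in fused]

# With an empty mapping (the bug fixed by this commit), "qkv_proj"
# would be looked up literally and never match the checkpoint's
# q/k/v names, so fused modules were mis-detected.
print(shard_names("model.layers.0.self_attn.qkv_proj"))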