[Model][Quant] Fix GLM, Fix fused module mappings for quantization (#12634)

Signed-off-by: mgoin <michael@neuralmagic.com>
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
Co-authored-by: mgoin <michael@neuralmagic.com>
This commit is contained in:
Kyle Sayers
2025-02-05 00:32:06 -05:00
committed by GitHub
parent 686006a220
commit 7ff7a638b6
12 changed files with 194 additions and 150 deletions

View File

@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
"""This file is used for /tests and /benchmarks"""
from typing import List, Optional, Tuple
from types import MappingProxyType
from typing import List, Mapping, Optional, Tuple
import numpy
import torch
@@ -12,14 +13,6 @@ from vllm.scalar_type import ScalarType, scalar_types
SUPPORTED_GPTQ_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128]
SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
# Note: this is a hack. We should update each model to register the
# stacked params and get it from there instead in a future PR.
# fused_name: List[shard_name]
FUSED_LAYER_NAME_MAPPING = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
"gate_up_proj": ["gate_proj", "up_proj"]
}
# Normalize the group_shape to the full extent for any dims that are -1
def _normalize_quant_group_shape(x: torch.Tensor, group_shape: Tuple[int,
@@ -178,14 +171,23 @@ def unpack_quantized_values_into_int32(w_q: torch.Tensor,
return res.permute(inv_perm)
def is_layer_skipped(prefix: str, ignored_layers: List[str]) -> bool:
def is_layer_skipped(
prefix: str,
ignored_layers: List[str],
fused_mapping: Mapping[str, List[str]] = MappingProxyType({})
) -> bool:
# prefix: model.layers.0.self_attn.q_proj
# proj_name: q_proj
proj_name = prefix.split(".")[-1]
if proj_name in FUSED_LAYER_NAME_MAPPING:
# Fused layers like gate_up_proj or qkv_proj will not be fused
# in the safetensors checkpoint. So, we convert the name
# from the fused version to unfused + check to make sure that
# each shard of the fused layer has the same scheme.
if proj_name in fused_mapping:
shard_prefixes = [
prefix.replace(proj_name, shard_proj_name)
for shard_proj_name in FUSED_LAYER_NAME_MAPPING[proj_name]
for shard_proj_name in fused_mapping[proj_name]
]
is_skipped = None