[Model][Quant] Fix GLM, Fix fused module mappings for quantization (#12634)
Signed-off-by: mgoin <michael@neuralmagic.com>
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
Co-authored-by: mgoin <michael@neuralmagic.com>
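The diff below removes the module-level FUSED_LAYER_NAME_MAPPING default and adds a fused_mapping argument to is_layer_skipped, so each caller can describe how its own model fuses modules instead of relying on the Llama-style names hard-coded before. A minimal sketch of the kind of mappings a caller might now supply; only the first dict mirrors the removed default, and the GLM-style names are illustrative assumptions, not taken from this commit:

    # Hypothetical per-model fused-module mappings a caller could pass to
    # is_layer_skipped(). The first mirrors the old hard-coded default.
    llama_style_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }
    # Illustrative only: a checkpoint that already stores fused weights can
    # map each fused name to itself, so no shard renaming is attempted when
    # checking ignored layers.
    glm_style_mapping = {
        "query_key_value": ["query_key_value"],
        "dense_h_to_4h": ["dense_h_to_4h"],
    }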
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 """This file is used for /tests and /benchmarks"""
-from typing import List, Optional, Tuple
+from types import MappingProxyType
+from typing import List, Mapping, Optional, Tuple
 
 import numpy
 import torch
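The new types.MappingProxyType import is used later in this diff as the default value of the added fused_mapping parameter. A short sketch of why a read-only proxy is a safe default (my reading, not stated in the commit):

    from types import MappingProxyType

    # A shared empty default is safe because the proxy rejects mutation,
    # avoiding the classic mutable-default-argument pitfall while still
    # supporting cheap `proj_name in fused_mapping` checks.
    EMPTY_MAPPING = MappingProxyType({})
    try:
        EMPTY_MAPPING["qkv_proj"] = ["q_proj"]
    except TypeError:
        pass  # mappingproxy objects do not support item assignment
    assert "qkv_proj" not in EMPTY_MAPPING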
@@ -12,14 +13,6 @@ from vllm.scalar_type import ScalarType, scalar_types
 SUPPORTED_GPTQ_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128]
 SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
 
-# Note: this is a hack. We should update each model to register the
-# stacked params and get it from there instead in a future PR.
-# fused_name: List[shard_name]
-FUSED_LAYER_NAME_MAPPING = {
-    "qkv_proj": ["q_proj", "k_proj", "v_proj"],
-    "gate_up_proj": ["gate_proj", "up_proj"]
-}
-
 
 # Normalize the group_shape to the full extent for any dims that are -1
 def _normalize_quant_group_shape(x: torch.Tensor, group_shape: Tuple[int,
@@ -178,14 +171,23 @@ def unpack_quantized_values_into_int32(w_q: torch.Tensor,
     return res.permute(inv_perm)
 
 
-def is_layer_skipped(prefix: str, ignored_layers: List[str]) -> bool:
+def is_layer_skipped(
+        prefix: str,
+        ignored_layers: List[str],
+        fused_mapping: Mapping[str, List[str]] = MappingProxyType({})
+) -> bool:
     # prefix: model.layers.0.self_attn.q_proj
     # proj_name: q_proj
     proj_name = prefix.split(".")[-1]
-    if proj_name in FUSED_LAYER_NAME_MAPPING:
+
+    # Fused layers like gate_up_proj or qkv_proj will not be fused
+    # in the safetensors checkpoint. So, we convert the name
+    # from the fused version to unfused + check to make sure that
+    # each shard of the fused layer has the same scheme.
+    if proj_name in fused_mapping:
         shard_prefixes = [
             prefix.replace(proj_name, shard_proj_name)
-            for shard_proj_name in FUSED_LAYER_NAME_MAPPING[proj_name]
+            for shard_proj_name in fused_mapping[proj_name]
         ]
 
         is_skipped = None
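A minimal usage sketch of the updated helper, assuming the rest of its body (not shown in this hunk) keeps the prior semantics of exact membership checks against ignored_layers; the import path and layer names are assumptions, not taken from this diff:

    # Import path assumed; the diff above does not name the file.
    from vllm.model_executor.layers.quantization.utils.quant_utils import (
        is_layer_skipped)

    fused_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}

    # Unfused prefix: checked directly against ignored_layers
    # (expected True under exact-match semantics).
    unfused_skipped = is_layer_skipped(
        "model.layers.0.mlp.down_proj",
        ignored_layers=["model.layers.0.mlp.down_proj"],
        fused_mapping=fused_mapping)

    # Fused prefix: expanded to q_proj/k_proj/v_proj shard prefixes, which
    # must all agree on whether they are ignored.
    fused_skipped = is_layer_skipped(
        "model.layers.0.self_attn.qkv_proj",
        ignored_layers=["model.layers.0.self_attn.q_proj",
                        "model.layers.0.self_attn.k_proj",
                        "model.layers.0.self_attn.v_proj"],
        fused_mapping=fused_mapping)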