[Misc] Add ignored layers for fp8 quantization (#6657)

This commit is contained in:
Michael Goin
2024-07-23 14:04:04 -04:00
committed by GitHub
parent 38c4b7e863
commit 0eb0757bef
4 changed files with 57 additions and 47 deletions

View File

@@ -1,10 +1,48 @@
"""This file is used for /tests and /benchmarks"""
from typing import List
import numpy
import torch
SUPPORTED_NUM_BITS = [4, 8]
SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
# Note: this is a hack. We should update each model to register the
# stacked params and get it from there instead in a future PR.
# fused_name: List[shard_name]
FUSED_LAYER_NAME_MAPPING = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
"gate_up_proj": ["gate_proj", "up_proj"]
}
def is_layer_skipped(prefix: str, ignored_layers: List[str]) -> bool:
# prefix: model.layers.0.self_attn.q_proj
# proj_name: q_proj
proj_name = prefix.split(".")[-1]
if proj_name in FUSED_LAYER_NAME_MAPPING:
shard_prefixes = [
prefix.replace(proj_name, shard_proj_name)
for shard_proj_name in FUSED_LAYER_NAME_MAPPING[proj_name]
]
is_skipped = None
for shard_prefix in shard_prefixes:
is_shard_skipped = shard_prefix in ignored_layers
if is_skipped is None:
is_skipped = is_shard_skipped
elif is_shard_skipped != is_skipped:
raise ValueError(
f"Detected some but not all shards of {prefix} "
"are quantized. All shards of fused layers "
"to have the same precision.")
else:
is_skipped = prefix in ignored_layers
assert is_skipped is not None
return is_skipped
def get_pack_factor(num_bits):
assert num_bits in SUPPORTED_NUM_BITS, f"Unsupported num_bits = {num_bits}"