[Misc] Add ignored layers for fp8 quantization (#6657)
@@ -1,10 +1,48 @@
"""This file is used for /tests and /benchmarks"""
from typing import List

import numpy
import torch

SUPPORTED_NUM_BITS = [4, 8]
SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]

# Note: this is a hack. We should update each model to register the
# stacked params and get it from there instead in a future PR.
# fused_name: List[shard_name]
FUSED_LAYER_NAME_MAPPING = {
    "qkv_proj": ["q_proj", "k_proj", "v_proj"],
    "gate_up_proj": ["gate_proj", "up_proj"]
}
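# Hedged note (illustration, not part of this commit): vLLM stacks the
# per-shard checkpoint weights on the right into the single fused parameter
# named on the left (e.g. q_proj/k_proj/v_proj -> qkv_proj), so an ignore
# list written against checkpoint names has to be checked shard-by-shard.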


def is_layer_skipped(prefix: str, ignored_layers: List[str]) -> bool:
    # prefix: model.layers.0.self_attn.q_proj
    # proj_name: q_proj
    proj_name = prefix.split(".")[-1]
    if proj_name in FUSED_LAYER_NAME_MAPPING:
        shard_prefixes = [
            prefix.replace(proj_name, shard_proj_name)
            for shard_proj_name in FUSED_LAYER_NAME_MAPPING[proj_name]
        ]

        is_skipped = None
        for shard_prefix in shard_prefixes:
            is_shard_skipped = shard_prefix in ignored_layers

            if is_skipped is None:
                is_skipped = is_shard_skipped
            elif is_shard_skipped != is_skipped:
                raise ValueError(
                    f"Detected some but not all shards of {prefix} "
                    "are quantized. All shards of fused layers "
                    "must have the same precision.")
    else:
        is_skipped = prefix in ignored_layers

    assert is_skipped is not None
    return is_skipped
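# Hedged usage sketch (illustration only, not part of this commit); the
# module path below is hypothetical:
#
#     ignored = [
#         "model.layers.0.self_attn.q_proj",
#         "model.layers.0.self_attn.k_proj",
#         "model.layers.0.self_attn.v_proj",
#     ]
#     is_layer_skipped("model.layers.0.self_attn.qkv_proj", ignored)  # True
#     is_layer_skipped("model.layers.0.mlp.down_proj", ignored)       # False
#
# Listing only some of the q/k/v shards would raise ValueError, since the
# fused layer's shards would then disagree on precision.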


def get_pack_factor(num_bits):
    assert num_bits in SUPPORTED_NUM_BITS, f"Unsupported num_bits = {num_bits}"
    # Number of quantized values packed into one 32-bit word.
    return 32 // num_bits
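# Hedged example (illustration only): get_pack_factor(4) == 8 and
# get_pack_factor(8) == 4, i.e. eight 4-bit or four 8-bit values fit in
# one 32-bit word.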