[Misc] Add ignored layers for fp8 quantization (#6657)
This commit is contained in:
@@ -11,6 +11,8 @@ from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig, QuantizeMethodBase)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
|
||||
apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
is_layer_skipped)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
apply_fp8_linear, create_per_channel_scale_param)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
@@ -18,14 +20,6 @@ from vllm.platforms import current_platform
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# Note: this is a hack. We should update each model to register the
|
||||
# stacked params and get it from there instead in a future PR.
|
||||
# fused_name: List[shard_name]
|
||||
_FUSED_LAYER_NAME_MAPPING = {
|
||||
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
|
||||
"gate_up_proj": ["gate_proj", "up_proj"]
|
||||
}
|
||||
|
||||
|
||||
class FBGEMMFp8Config(QuantizationConfig):
|
||||
"""Config class for FBGEMM Fp8."""
|
||||
@@ -62,37 +56,10 @@ class FBGEMMFp8Config(QuantizationConfig):
|
||||
input_scale_ub = cls.get_from_keys(config, ["activation_scale_ub"])
|
||||
return cls(ignore_list=ignore_list, input_scale_ub=input_scale_ub)
|
||||
|
||||
def _is_layer_skipped(self, prefix: str) -> bool:
|
||||
# prefix: model.layers.0.self_attn.q_proj
|
||||
# proj_name: q_proj
|
||||
proj_name = prefix.split(".")[-1]
|
||||
if proj_name in _FUSED_LAYER_NAME_MAPPING:
|
||||
shard_prefixes = [
|
||||
prefix.replace(proj_name, shard_proj_name)
|
||||
for shard_proj_name in _FUSED_LAYER_NAME_MAPPING[proj_name]
|
||||
]
|
||||
|
||||
is_skipped = None
|
||||
for shard_prefix in shard_prefixes:
|
||||
is_shard_skipped = shard_prefix in self.ignore_list
|
||||
|
||||
if is_skipped is None:
|
||||
is_skipped = is_shard_skipped
|
||||
elif is_shard_skipped != is_skipped:
|
||||
raise ValueError(
|
||||
f"Detected some but not all shards of {prefix} "
|
||||
"are quantized. All shards of fused layers "
|
||||
"to have the same precision.")
|
||||
else:
|
||||
is_skipped = prefix in self.ignore_list
|
||||
|
||||
assert is_skipped is not None
|
||||
return is_skipped
|
||||
|
||||
def get_quant_method(self, layer: torch.nn.Module,
|
||||
prefix: str) -> Optional["QuantizeMethodBase"]:
|
||||
if isinstance(layer, LinearBase):
|
||||
if self._is_layer_skipped(prefix):
|
||||
if is_layer_skipped(prefix, self.ignore_list):
|
||||
return UnquantizedLinearMethod()
|
||||
return FBGEMMFp8LinearMethod(self)
|
||||
return None
|
||||
|
||||
Reference in New Issue
Block a user