[Model] Support quantization of PixtralHFTransformer for PixtralHF (#9921)

Signed-off-by: mgoin <michael@neuralmagic.com>
Author: Michael Goin
Committed: 2024-11-05 13:42:20 -05:00 (via GitHub)
parent 731aec5be7
commit a53046b16f
2 changed files with 90 additions and 40 deletions

@@ -299,3 +299,33 @@ def get_act_fn(
        return ScaledActivation(act_fn, intermediate_size, input_is_parallel,
                                params_dtype)
    return act_fn


_ACTIVATION_AND_MUL_REGISTRY = LazyDict({
    "gelu": lambda: GeluAndMul(),
    "silu": lambda: SiluAndMul(),
})


def get_act_and_mul_fn(
    act_fn_name: str,
    quant_config: Optional[QuantizationConfig] = None,
    intermediate_size: Optional[int] = None,
    input_is_parallel: bool = True,
    params_dtype: Optional[torch.dtype] = None,
) -> nn.Module:
    """Get an activation-and-mul (i.e. SiluAndMul) function by name."""
    act_fn_name = act_fn_name.lower()
    if act_fn_name not in _ACTIVATION_AND_MUL_REGISTRY:
        raise ValueError(
            f"Activation function {act_fn_name!r} is not supported.")

    act_fn = _ACTIVATION_AND_MUL_REGISTRY[act_fn_name]
    if (quant_config is not None
            and act_fn_name in quant_config.get_scaled_act_names()):
        if intermediate_size is None:
            raise ValueError("intermediate_size must be specified for scaled "
                             "activation functions.")
        return ScaledActivation(act_fn, intermediate_size, input_is_parallel,
                                params_dtype)

    return act_fn
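
For context, here is a minimal sketch of how a gated MLP (such as the feed-forward block in the PixtralHF vision tower) could consume the new helper. The class name GatedMLP and the use of plain nn.Linear in place of vLLM's parallel, quantization-aware linear layers are illustrative assumptions, not code from this commit; the import path matches the module where get_act_fn is defined in this diff.

import torch
import torch.nn as nn

from vllm.model_executor.layers.activation import get_act_and_mul_fn


class GatedMLP(nn.Module):
    """Illustrative gated MLP; nn.Linear stands in for vLLM's parallel,
    quantization-aware linear layers."""

    def __init__(self, hidden_size: int, intermediate_size: int) -> None:
        super().__init__()
        # The fused gate/up projection emits 2 * intermediate_size features;
        # the act-and-mul op computes act(gate) * up and halves that width.
        self.gate_up_proj = nn.Linear(hidden_size, 2 * intermediate_size,
                                      bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
        # "silu" resolves to SiluAndMul() through _ACTIVATION_AND_MUL_REGISTRY.
        self.act_and_mul = get_act_and_mul_fn("silu")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.down_proj(self.act_and_mul(self.gate_up_proj(x)))

In a quantized model, a quant_config would be passed as the second argument so that schemes listing the activation in get_scaled_act_names() receive a ScaledActivation wrapper instead of the bare op.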