Add AWQ support for all models (#1714)
@@ -1,8 +1,11 @@
 """Custom activation functions."""
+from typing import Optional
+
 import torch
 import torch.nn as nn
 
 from vllm import activation_ops
+from vllm.model_executor.layers.quantization import QuantizationConfig
 
 
 class SiluAndMul(nn.Module):
@@ -39,6 +42,27 @@ class FastGELU(nn.Module):
         return out
 
 
+class ScaledActivation(nn.Module):
+    """An activation function with post-scale parameters.
+
+    This is used for some quantization methods like AWQ.
+    """
+
+    def __init__(
+        self,
+        act_module: nn.Module,
+        hidden_size: int,
+        params_dtype: torch.dtype,
+    ):
+        super().__init__()
+        self.act = act_module
+        self.scales = nn.Parameter(
+            torch.empty(hidden_size, dtype=params_dtype, device="cuda"))
+
+    def forward(self, x: torch.Tensor):
+        return self.act(x) / self.scales
+
+
 _ACTIVATION_REGISTRY = {
     "gelu": nn.GELU(),
     "gelu_fast": FastGELU(),
@@ -48,9 +72,27 @@ _ACTIVATION_REGISTRY = {
 }
 
 
-def get_act_fn(act_fn: str) -> nn.Module:
+def get_act_fn(
+    act_fn_name: str,
+    quant_config: Optional[QuantizationConfig] = None,
+    intermediate_size: Optional[int] = None,
+) -> nn.Module:
     """Get an activation function by name."""
-    act_fn = act_fn.lower()
-    if act_fn in _ACTIVATION_REGISTRY:
-        return _ACTIVATION_REGISTRY[act_fn]
-    raise ValueError(f"Activation function {act_fn!r} is not supported.")
+    act_fn_name = act_fn_name.lower()
+    if act_fn_name not in _ACTIVATION_REGISTRY:
+        raise ValueError(
+            f"Activation function {act_fn_name!r} is not supported.")
+
+    act_fn = _ACTIVATION_REGISTRY[act_fn_name]
+    if quant_config is not None:
+        if act_fn_name in quant_config.get_scaled_act_names():
+            if intermediate_size is None:
+                raise ValueError(
+                    "intermediate_size must be specified for scaled "
+                    "activation functions.")
+            return ScaledActivation(
+                act_fn,
+                intermediate_size,
+                params_dtype=torch.get_default_dtype(),
+            )
+    return act_fn
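
Not part of the commit, but a small usage sketch of the new signature may help: get_act_fn keeps its old single-argument behavior and only wraps the activation in ScaledActivation when a quantization config lists that activation name. DummyConfig below is a hypothetical stand-in for a real config such as AWQConfig, and ScaledActivation allocates its scales on CUDA, so the last two lines assume a GPU.

    from typing import List

    from vllm.model_executor.layers.activation import ScaledActivation, get_act_fn


    class DummyConfig:  # hypothetical stand-in for a QuantizationConfig
        def get_scaled_act_names(self) -> List[str]:
            return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"]


    # Without a config the behavior is unchanged: the registry module is returned.
    act = get_act_fn("gelu_fast")

    # With a config that lists the activation, intermediate_size becomes mandatory
    # and the result comes back wrapped in ScaledActivation.
    scaled_act = get_act_fn("gelu_fast", DummyConfig(), intermediate_size=11008)
    assert isinstance(scaled_act, ScaledActivation)
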
@@ -63,6 +63,9 @@ class AWQConfig(QuantizationConfig):
     def get_linear_method(self) -> "AWQLinearMethod":
         return AWQLinearMethod(self)
 
+    def get_scaled_act_names(self) -> List[str]:
+        return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"]
+
 
 class AWQLinearMethod(LinearMethodBase):
     """Linear method for AWQ.
@@ -54,3 +54,11 @@ class QuantizationConfig(ABC):
     def get_linear_method(self) -> LinearMethodBase:
         """Get the linear method to use for the quantized linear layer."""
         raise NotImplementedError
+
+    @abstractmethod
+    def get_scaled_act_names(self) -> List[str]:
+        """Returns the activation function names that should be post-scaled.
+
+        For now, this is only used by AWQ.
+        """
+        raise NotImplementedError
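
Since get_scaled_act_names is declared with @abstractmethod on the base class, every QuantizationConfig implementation now has to provide it, as AWQConfig does above and SqueezeLLMConfig does below. A minimal sketch for a hypothetical out-of-tree config follows; the remaining abstract methods of QuantizationConfig are omitted for brevity, so the class is not instantiable as written.

    from typing import List

    from vllm.model_executor.layers.quantization import QuantizationConfig


    class MyQuantConfig(QuantizationConfig):  # hypothetical, incomplete example
        def get_scaled_act_names(self) -> List[str]:
            # Return an empty list when no activation needs post-scaling.
            return []
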
@@ -52,6 +52,9 @@ class SqueezeLLMConfig(QuantizationConfig):
     def get_linear_method(self) -> "SqueezeLLMLinearMethod":
         return SqueezeLLMLinearMethod(self)
 
+    def get_scaled_act_names(self) -> List[str]:
+        return []
+
 
 class SqueezeLLMLinearMethod(LinearMethodBase):
     """Linear method for SqueezeLLM.
@@ -145,7 +145,8 @@ class BloomMLP(nn.Module):
             4 * hidden_size,
             linear_method=linear_method,
         )
-        self.act = get_act_fn("gelu")
+        quant_config = getattr(linear_method, "quant_config", None)
+        self.gelu_impl = get_act_fn("gelu", quant_config, 4 * hidden_size)
         self.dense_4h_to_h = RowParallelLinear(
             4 * hidden_size,
             hidden_size,
@@ -154,7 +155,7 @@ class BloomMLP(nn.Module):
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x, _ = self.dense_h_to_4h(x)
-        x = self.act(x)
+        x = self.gelu_impl(x)
         x, _ = self.dense_4h_to_h(x)
         return x
@@ -27,6 +27,7 @@ from torch.nn import LayerNorm
 from transformers import FalconConfig as HF_FalconConfig
 
 from vllm.model_executor.input_metadata import InputMetadata
+from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.attention import (PagedAttention,
                                                    PagedAttentionWithALiBi,
                                                    PagedAttentionWithRoPE)
@@ -131,6 +132,7 @@ class FalconAttention(nn.Module):
             self.hidden_size,
             bias=config.bias,
             skip_bias_add=True,
+            linear_method=linear_method,
             reduce_results=self.reduce_row_parallel_results)
 
         self.use_rotary = config.rotary
@@ -206,7 +208,8 @@ class FalconMLP(nn.Module):
             bias=config.bias,
             skip_bias_add=True,
             linear_method=linear_method)
-        self.act = nn.GELU()
+        quant_config = getattr(linear_method, "quant_config", None)
+        self.act = get_act_fn("gelu", quant_config, 4 * hidden_size)
         self.reduce_row_parallel_results = not (config.new_decoder_architecture
                                                 or config.parallel_attn)
         self.dense_4h_to_h = RowParallelLinear(
@@ -118,7 +118,9 @@ class GPT2MLP(nn.Module):
             bias=True,
             linear_method=linear_method,
         )
-        self.act = get_act_fn(config.activation_function)
+        quant_config = getattr(linear_method, "quant_config", None)
+        self.act = get_act_fn(config.activation_function, quant_config,
+                              intermediate_size)
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states, _ = self.c_fc(hidden_states)
@@ -137,7 +137,9 @@ class GPTBigMLP(nn.Module):
             bias=True,
             linear_method=linear_method,
         )
-        self.act = get_act_fn(config.activation_function)
+        quant_config = getattr(linear_method, "quant_config", None)
+        self.act = get_act_fn(config.activation_function, quant_config,
+                              intermediate_size)
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states, _ = self.c_fc(hidden_states)
@@ -128,7 +128,9 @@ class GPTJMLP(nn.Module):
             hidden_size,
             linear_method=linear_method,
         )
-        self.act = get_act_fn(config.activation_function)
+        quant_config = getattr(linear_method, "quant_config", None)
+        self.act = get_act_fn(config.activation_function, quant_config,
+                              intermediate_size)
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states, _ = self.fc_in(hidden_states)
@@ -124,7 +124,9 @@ class GPTNeoXMLP(nn.Module):
             config.hidden_size,
             linear_method=linear_method,
         )
-        self.act = get_act_fn(config.hidden_act)
+        quant_config = getattr(linear_method, "quant_config", None)
+        self.act = get_act_fn(config.hidden_act, quant_config,
+                              config.intermediate_size)
 
     def forward(self, hidden_states):
         hidden_states, _ = self.dense_h_to_4h(hidden_states)
@@ -130,7 +130,8 @@ class MPTMLP(nn.Module):
             bias=not config.no_bias,
             linear_method=linear_method,
         )
-        self.act = get_act_fn("gelu")
+        quant_config = getattr(linear_method, "quant_config", None)
+        self.act = get_act_fn("gelu", quant_config, intermediate_size)
         self.down_proj = RowParallelLinear(
             intermediate_size,
             hidden_size,
@@ -129,7 +129,9 @@ class OPTDecoderLayer(nn.Module):
             linear_method=linear_method,
         )
         self.do_layer_norm_before = config.do_layer_norm_before
-        self.activation_fn = get_act_fn(config.activation_function)
+        quant_config = getattr(linear_method, "quant_config", None)
+        self.activation_fn = get_act_fn(config.activation_function,
+                                        quant_config, config.ffn_dim)
 
         self.self_attn_layer_norm = nn.LayerNorm(
             self.embed_dim,
@@ -251,7 +253,7 @@ class OPTDecoder(nn.Module):
         inputs_embeds = self.embed_tokens(input_ids)
         pos_embeds = self.embed_positions(positions)
         if self.project_in is not None:
-            inputs_embeds = self.project_in(inputs_embeds)
+            inputs_embeds, _ = self.project_in(inputs_embeds)
         hidden_states = inputs_embeds + pos_embeds
 
         for i in range(len(self.layers)):
@@ -266,7 +268,7 @@ class OPTDecoder(nn.Module):
         if self.final_layer_norm is not None:
             hidden_states = self.final_layer_norm(hidden_states)
         if self.project_out is not None:
-            hidden_states = self.project_out(hidden_states)
+            hidden_states, _ = self.project_out(hidden_states)
         return hidden_states
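
The OPT change above only adds tuple unpacking at the project_in/project_out call sites, which suggests those projections now follow the same calling convention as the parallel linear layers used elsewhere in this diff: forward returns an (output, optional bias) pair. A toy sketch of that convention, with purely illustrative names:

    import torch
    import torch.nn as nn


    class TupleReturningLinear(nn.Module):  # illustrative stand-in only
        """Mimics the (output, output_bias) return convention assumed above."""

        def __init__(self, in_features: int, out_features: int):
            super().__init__()
            self.weight = nn.Parameter(torch.randn(out_features, in_features))

        def forward(self, x: torch.Tensor):
            return x @ self.weight.t(), None


    proj = TupleReturningLinear(8, 16)
    hidden_states, _ = proj(torch.randn(2, 8))  # unpack, as in the diff above
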
@@ -168,7 +168,9 @@ class PhiMLP(nn.Module):
             config.hidden_size,
             linear_method=linear_method,
         )
-        self.act = get_act_fn(config.activation_function)
+        quant_config = getattr(linear_method, "quant_config", None)
+        self.act = get_act_fn(config.activation_function, quant_config,
+                              n_inner)
 
     def forward(self, hidden_states):
         hidden_states, _ = self.fc1(hidden_states)