[Model] Support quantization of PixtralHFTransformer for PixtralHF (#9921)

Signed-off-by: mgoin <michael@neuralmagic.com>
This commit is contained in:
Michael Goin
2024-11-05 13:42:20 -05:00
committed by GitHub
parent 731aec5be7
commit a53046b16f
2 changed files with 90 additions and 40 deletions

View File

@@ -299,3 +299,33 @@ def get_act_fn(
return ScaledActivation(act_fn, intermediate_size, input_is_parallel,
params_dtype)
return act_fn
_ACTIVATION_AND_MUL_REGISTRY = LazyDict({
"gelu": lambda: GeluAndMul(),
"silu": lambda: SiluAndMul(),
})
def get_act_and_mul_fn(
act_fn_name: str,
quant_config: Optional[QuantizationConfig] = None,
intermediate_size: Optional[int] = None,
input_is_parallel: bool = True,
params_dtype: Optional[torch.dtype] = None,
) -> nn.Module:
"""Get an activation-and-mul (i.e. SiluAndMul) function by name."""
act_fn_name = act_fn_name.lower()
if act_fn_name not in _ACTIVATION_AND_MUL_REGISTRY:
raise ValueError(
f"Activation function {act_fn_name!r} is not supported.")
act_fn = _ACTIVATION_AND_MUL_REGISTRY[act_fn_name]
if (quant_config is not None
and act_fn_name in quant_config.get_scaled_act_names()):
if intermediate_size is None:
raise ValueError("intermediate_size must be specified for scaled "
"activation functions.")
return ScaledActivation(act_fn, intermediate_size, input_is_parallel,
params_dtype)
return act_fn

View File

@@ -19,8 +19,11 @@ from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, ModelConfig, MultiModalConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext, token_inputs)
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.activation import get_act_and_mul_fn
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@@ -798,20 +801,24 @@ class PixtralHFMLP(nn.Module):
super().__init__()
assert config.intermediate_size is not None
# TODO: Use quant_config and prefix after optimizing this
self.gate_proj = nn.Linear(config.hidden_size,
config.intermediate_size,
bias=False)
self.up_proj = nn.Linear(config.hidden_size,
config.intermediate_size,
bias=False)
self.down_proj = nn.Linear(config.intermediate_size,
config.hidden_size,
bias=False)
self.act = get_act_fn(config.hidden_act)
self.gate_up_proj = MergedColumnParallelLinear(
input_size=config.hidden_size,
output_sizes=[config.intermediate_size] * 2,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj")
self.down_proj = RowParallelLinear(input_size=config.intermediate_size,
output_size=config.hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.down_proj")
self.act_and_mul = get_act_and_mul_fn(config.hidden_act)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x))
gate_up, _ = self.gate_up_proj(x)
x = self.act_and_mul(gate_up)
x, _ = self.down_proj(x)
return x
class PixtralHFAttention(nn.Module):
@@ -830,21 +837,21 @@ class PixtralHFAttention(nn.Module):
self.n_heads = config.num_attention_heads
self.head_dim = config.hidden_size // config.num_attention_heads
self.scale = self.head_dim**-0.5
# TODO: Use quant_config and prefix after optimizing this
self.q_proj = nn.Linear(config.hidden_size,
config.hidden_size,
bias=False)
self.k_proj = nn.Linear(config.hidden_size,
config.hidden_size,
bias=False)
self.v_proj = nn.Linear(config.hidden_size,
config.hidden_size,
bias=False)
self.o_proj = nn.Linear(config.hidden_size,
config.hidden_size,
bias=False)
self.qkv_proj = QKVParallelLinear(
hidden_size=config.hidden_size,
head_size=self.head_dim,
total_num_heads=self.n_heads,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.o_proj = RowParallelLinear(
input_size=config.hidden_size,
output_size=config.hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
)
def forward(
self,
@@ -854,13 +861,13 @@ class PixtralHFAttention(nn.Module):
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
batch, patches, _ = hidden_states.size()
q = self.q_proj(hidden_states)
k = self.k_proj(hidden_states)
v = self.v_proj(hidden_states)
qkv_states, _ = self.qkv_proj(hidden_states)
q, k, v = qkv_states.chunk(3, dim=-1)
# Transpose q and k to apply HF's Rotary Position Embedding
q = q.view(batch, patches, self.n_heads, self.head_dim).transpose(1, 2)
k = k.view(batch, patches, self.n_heads, self.head_dim).transpose(1, 2)
v = v.view(batch, patches, self.n_heads, self.head_dim)
cos, sin = position_embeddings
q, k = apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=0)
@@ -868,22 +875,21 @@ class PixtralHFAttention(nn.Module):
# Transpose q and k back for attention
q = q.transpose(1, 2).contiguous()
k = k.transpose(1, 2).contiguous()
v = v.reshape(batch, patches, self.n_heads, self.head_dim)
out = xops.memory_efficient_attention(q,
k,
v,
attn_bias=attention_mask)
else:
v = v.reshape(batch, patches, self.n_heads,
self.head_dim).transpose(1, 2)
v = v.transpose(1, 2)
out = nn.functional.scaled_dot_product_attention(
q, k, v, attn_mask=attention_mask)
out = out.transpose(1, 2)
out = out.reshape(batch, patches, self.n_heads * self.head_dim)
out = out.view(batch, patches, self.n_heads * self.head_dim)
attn_output, _ = self.o_proj(out)
return self.o_proj(out)
return attn_output, None
class PixtralHFTransformerBlock(nn.Module):
@@ -912,9 +918,9 @@ class PixtralHFTransformerBlock(nn.Module):
attention_mask: torch.Tensor,
position_embeddings: torch.Tensor,
) -> torch.Tensor:
r = self.attention.forward(self.attention_norm(hidden_states),
attention_mask=attention_mask,
position_embeddings=position_embeddings)
r, _ = self.attention.forward(self.attention_norm(hidden_states),
attention_mask=attention_mask,
position_embeddings=position_embeddings)
h = hidden_states + r
r = self.feed_forward.forward(self.ffn_norm(h))
out = h + r
@@ -1053,10 +1059,24 @@ class PixtralHFVisionModel(nn.Module):
# (TODO) Add prefix argument for filtering out weights to be loaded
# ref: https://github.com/vllm-project/vllm/pull/7186#discussion_r1734163986
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = []
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"),
(".qkv_proj", ".k_proj", "k"),
(".qkv_proj", ".v_proj", "v"),
(".gate_up_proj", ".gate_proj", 0),
(".gate_up_proj", ".up_proj", 1),
]
params_dict = dict(self.named_parameters())
layer_count = len(self.transformer.layers)
for name, loaded_weight in weights:
# omit layers when num_hidden_layers_override is set
if name.startswith("transformer.layers"):
layer_idx = int(name.split(".")[2])
if layer_idx >= layer_count:
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue