TP/quantization/weight loading refactor part 2 - Refactor quantized linear logic and extend quantization support to all models (#1622)

Refactor the tensor parallelism, quantization, and weight-loading code.

Summary of the new features enabled by this PR:
- **All models** can now be quantized with AWQ and SqueezeLLM, and [soon with GPTQ](https://github.com/vllm-project/vllm/pull/1580); a short usage sketch follows this list.
- The model-loading code becomes much simpler.
- Tensor parallelism now works for all MQA/GQA models, even when the number of key/value heads is smaller than the tensor parallel size (the KV heads are replicated across GPUs; see the GLMAttention hunk and the sketch after it below).
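
For the quantization point above, a minimal usage sketch. This is hedged: the checkpoint name and the `tensor_parallel_size` value are illustrative placeholders, not part of this PR.

```python
from vllm import LLM, SamplingParams

# After this refactor the same quantization flag works for any supported
# architecture; "TheBloke/Llama-2-7B-AWQ" is just an example AWQ checkpoint.
llm = LLM(
    model="TheBloke/Llama-2-7B-AWQ",
    quantization="awq",
    tensor_parallel_size=2,  # quantization and TP can be combined
)

outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)
```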
Author: Zhuohan Li (committed by GitHub), 2023-11-15 22:50:41 -08:00
Commit: 7076fa1c9f, parent: 660a7fcfa4
36 changed files with 2159 additions and 2508 deletions

File: vllm/model_executor/models/chatglm.py

@@ -6,32 +6,28 @@
The input of the model is flattened to a 1D tensor of tokens. The model uses
InputMetadata to extract the original 2D shape of the input.
"""
from typing import Dict, List, Optional, Tuple
from typing import List, Optional, Tuple
import torch
from torch import nn
from torch.nn import LayerNorm
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.linear import (LinearMethodBase,
MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.weight_utils import (
hf_model_weights_iterator,
load_tensor_parallel_weights,
)
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from vllm.model_executor.parallel_utils.layers import VocabParallelEmbedding
from vllm.model_executor.parallel_utils.layers import (
ColumnParallelLinear,
RowParallelLinear,
)
from vllm.sequence import SequenceOutputs
get_tensor_model_parallel_world_size)
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs import ChatGLMConfig
KVCache = Tuple[torch.Tensor, torch.Tensor]
@@ -39,7 +35,11 @@ KVCache = Tuple[torch.Tensor, torch.Tensor]
class GLMAttention(nn.Module):
def __init__(self, config):
def __init__(
self,
config,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__()
self.hidden_size = config.hidden_size
tp_size = get_tensor_model_parallel_world_size()
@@ -50,25 +50,33 @@ class GLMAttention(nn.Module):
self.total_num_kv_heads = (config.multi_query_group_num
if config.multi_query_attention else
config.num_attention_heads)
assert self.total_num_kv_heads % tp_size == 0
self.num_kv_heads = self.total_num_kv_heads // tp_size
if self.total_num_kv_heads >= tp_size:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % tp_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
self.head_dim = config.hidden_size // self.total_num_heads
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.query_key_value = ColumnParallelLinear(
config.hidden_size,
(self.total_num_heads + 2 * self.total_num_kv_heads) *
self.query_key_value = QKVParallelLinear(
self.hidden_size,
self.head_dim,
bias=config.add_qkv_bias,
gather_output=False,
self.total_num_heads,
self.total_num_kv_heads,
bias=config.add_bias_linear or config.add_qkv_bias,
linear_method=linear_method,
)
self.dense = RowParallelLinear(
self.total_num_heads * self.head_dim,
config.hidden_size,
bias=config.add_bias_linear,
input_is_parallel=True,
linear_method=linear_method,
)
self.attn = PagedAttentionWithRoPE(
@@ -78,7 +86,6 @@ class GLMAttention(nn.Module):
rotary_dim=self.head_dim // 2,
num_kv_heads=self.num_kv_heads,
is_neox_style=False,
# is_glm_style=True
)
def forward(
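
Note: the partition/replicate branch above is the heart of the new MQA/GQA tensor-parallel support. A standalone sketch of the rule (plain Python, not vLLM code):

```python
def num_local_kv_heads(total_num_kv_heads: int, tp_size: int) -> int:
    """Per-GPU KV head count under tensor parallelism."""
    if total_num_kv_heads >= tp_size:
        # Enough KV heads: partition them across the TP GPUs.
        assert total_num_kv_heads % tp_size == 0
    else:
        # Fewer KV heads than GPUs: each head is replicated on
        # tp_size // total_num_kv_heads GPUs.
        assert tp_size % total_num_kv_heads == 0
    return max(1, total_num_kv_heads // tp_size)

# ChatGLM2-style MQA: 2 KV groups on 8 GPUs -> 1 KV head per GPU (replicated).
assert num_local_kv_heads(2, 8) == 1
# Regular MHA: 32 KV heads on 8 GPUs -> 4 KV heads per GPU (partitioned).
assert num_local_kv_heads(32, 8) == 4
```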
@@ -117,17 +124,21 @@ class GLMMLP(nn.Module):
state back into h hidden dimension.
"""
def __init__(self, config):
def __init__(
self,
config,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__()
self.add_bias = config.add_bias_linear
# Project to 4h.
self.dense_h_to_4h = ColumnParallelLinear(
self.dense_h_to_4h = MergedColumnParallelLinear(
config.hidden_size,
config.ffn_hidden_size * 2,
[config.ffn_hidden_size] * 2,
bias=config.add_bias_linear,
gather_output=False,
linear_method=linear_method,
)
self.activation_func = SiluAndMul()
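
Note: switching to MergedColumnParallelLinear above keeps the gate and up projections fused in one GEMM (while letting the weight loader fill each shard separately), and SiluAndMul then splits the concatenated output. A rough shape sketch, with illustrative ChatGLM2-like sizes:

```python
import torch

hidden_size, ffn_hidden_size = 4096, 13696  # illustrative sizes
x = torch.randn(1, hidden_size)

# Fused gate+up projection: a single weight of shape (2 * ffn, hidden).
w_merged = torch.randn(2 * ffn_hidden_size, hidden_size)
gate_up = x @ w_merged.t()               # (1, 2 * ffn_hidden_size)

# SiluAndMul applies silu(gate) * up to the two halves of the last dim.
gate, up = gate_up.chunk(2, dim=-1)
y = torch.nn.functional.silu(gate) * up  # (1, ffn_hidden_size)
```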
@@ -137,7 +148,7 @@ class GLMMLP(nn.Module):
config.ffn_hidden_size,
config.hidden_size,
bias=config.add_bias_linear,
input_is_parallel=True,
linear_method=linear_method,
)
def forward(self, hidden_states):
@@ -159,6 +170,7 @@ class GLMBlock(nn.Module):
def __init__(
self,
config,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__()
self.apply_residual_connection_post_layernorm = (
@@ -172,7 +184,7 @@ class GLMBlock(nn.Module):
eps=config.layernorm_epsilon)
# Self attention.
self.self_attention = GLMAttention(config)
self.self_attention = GLMAttention(config, linear_method)
self.hidden_dropout = config.hidden_dropout
# Layernorm on the attention output
@@ -180,7 +192,7 @@ class GLMBlock(nn.Module):
config.hidden_size, eps=config.layernorm_epsilon)
# MLP
self.mlp = GLMMLP(config)
self.mlp = GLMMLP(config, linear_method)
def forward(
self,
@@ -227,7 +239,11 @@ class GLMBlock(nn.Module):
class GLMTransformer(nn.Module):
"""Transformer class."""
def __init__(self, config):
def __init__(
self,
config,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__()
self.post_layer_norm = config.post_layer_norm
@@ -236,7 +252,7 @@ class GLMTransformer(nn.Module):
# Transformer layers.
self.layers = nn.ModuleList(
[GLMBlock(config) for i in range(self.num_layers)])
[GLMBlock(config, linear_method) for i in range(self.num_layers)])
if self.post_layer_norm:
layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm
@@ -274,7 +290,11 @@ class GLMTransformer(nn.Module):
class ChatGLMModel(nn.Module):
def __init__(self, config):
def __init__(
self,
config,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__()
self.embedding = VocabParallelEmbedding(config.padded_vocab_size,
@@ -283,15 +303,10 @@ class ChatGLMModel(nn.Module):
self.num_layers = config.num_layers
self.multi_query_group_num = config.multi_query_group_num
self.kv_channels = config.kv_channels
self.encoder = GLMTransformer(config)
self.encoder = GLMTransformer(config, linear_method)
self.output_layer = ColumnParallelLinear(
config.hidden_size,
config.padded_vocab_size,
bias=False,
gather_output=False,
params_dtype=config.torch_dtype,
)
self.output_layer = ParallelLMHead(config.padded_vocab_size,
config.hidden_size)
def forward(
self,
@@ -317,10 +332,15 @@ class ChatGLMModel(nn.Module):
class ChatGLMForCausalLM(nn.Module):
def __init__(self, config: ChatGLMConfig):
def __init__(
self,
config: ChatGLMConfig,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__()
self.config: ChatGLMConfig = config
self.transformer = ChatGLMModel(config)
self.linear_method = linear_method
self.transformer = ChatGLMModel(config, linear_method)
self.lm_head_weight = self.transformer.output_layer.weight
self.sampler = Sampler(config.padded_vocab_size)
@@ -331,78 +351,26 @@ class ChatGLMForCausalLM(nn.Module):
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> Dict[int, SequenceOutputs]:
) -> SamplerOutput:
hidden_states = self.transformer(input_ids, positions, kv_caches,
input_metadata, cache_events)
next_tokens = self.sampler(self.lm_head_weight, hidden_states,
input_metadata)
return next_tokens
_column_parallel_weights = [
"output_layer.weight",
"embedding.weight",
]
_row_parallel_weights = ["dense_4h_to_h", "self_attention.dense"]
def load_weights(
self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto",
revision: Optional[str] = None,
):
tp_rank = get_tensor_model_parallel_rank()
tp_size = get_tensor_model_parallel_world_size()
q_proj_shard_size = self.config.hidden_size // tp_size
kv_proj_shard_size = (self.config.hidden_size //
self.config.num_attention_heads *
self.config.multi_query_group_num // tp_size)
mlp_hidden_shard_size = self.config.ffn_hidden_size // tp_size
state_dict = self.state_dict()
def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto",
revision: Optional[str] = None):
params_dict = dict(self.named_parameters(remove_duplicate=False))
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format, revision):
if "rotary_pos_emb.inv_freq" in name:
continue
if "word_embeddings" in name:
name = name.replace(".word_embeddings", "")
if name in state_dict:
param = state_dict[name]
if "query_key_value" in name:
q_offset = q_proj_shard_size * tp_rank
k_offset = (q_proj_shard_size * tp_size +
kv_proj_shard_size * tp_rank)
v_offset = (q_proj_shard_size * tp_size +
kv_proj_shard_size * (tp_size + tp_rank))
wq = loaded_weight[q_offset:q_offset + q_proj_shard_size]
wk = loaded_weight[k_offset:k_offset + kv_proj_shard_size]
wv = loaded_weight[v_offset:v_offset + kv_proj_shard_size]
loaded_weight = torch.cat([wq, wk, wv], dim=0)
param.data.copy_(loaded_weight)
continue
if "dense_h_to_4h" in name:
w_gate = loaded_weight[mlp_hidden_shard_size *
tp_rank:mlp_hidden_shard_size *
(tp_rank + 1)]
w_proj = loaded_weight[mlp_hidden_shard_size *
(tp_size +
tp_rank):mlp_hidden_shard_size *
(tp_size + tp_rank + 1)]
loaded_weight = torch.cat([w_gate, w_proj], dim=0)
param.data.copy_(loaded_weight)
continue
load_tensor_parallel_weights(
param,
loaded_weight,
name,
self._column_parallel_weights,
self._row_parallel_weights,
tp_rank,
)
elif name == "transformer.rotary_pos_emb.inv_freq":
continue
else:
print("Warning never found tensor's name:", name)
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
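
Note: the new load_weights above is the pattern this PR rolls out across all models. The hand-written shard arithmetic is gone; each parameter created by a parallel layer (QKVParallelLinear, MergedColumnParallelLinear, RowParallelLinear, ...) carries its own weight_loader that knows how to shard or merge the checkpoint tensor. A hedged sketch of the convention (illustrative, not the exact vLLM implementation):

```python
import torch

def default_weight_loader(param: torch.nn.Parameter,
                          loaded_weight: torch.Tensor) -> None:
    # Trivial case: the checkpoint tensor maps 1:1 onto the parameter.
    assert param.size() == loaded_weight.size()
    param.data.copy_(loaded_weight)

def load_weights(model: torch.nn.Module, checkpoint_weights) -> None:
    # checkpoint_weights yields (name, tensor) pairs, e.g. from
    # hf_model_weights_iterator.
    params = dict(model.named_parameters(remove_duplicate=False))
    for name, loaded_weight in checkpoint_weights:
        param = params[name]
        # Parallel layers attach a weight_loader to their parameters;
        # everything else falls back to a plain copy.
        loader = getattr(param, "weight_loader", default_weight_loader)
        loader(param, loaded_weight)
```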