TP/quantization/weight loading refactor part 2 - Refactor quantized linear logic and extend quantization support to all models (#1622)

Refactor the tensor parallelism, quantization, and weight-loading codes.

Summary of the new features enabled by this PR:
- **All models** are able to be quantized with AWQ and SqueezeLLM, and [soon GPTQ](https://github.com/vllm-project/vllm/pull/1580).
- Model loading code became much simpler.
- Support model parallelism for all MQA/GQA models when the number of key/value heads is smaller than the tensor parallel size.
This commit is contained in:
Zhuohan Li
2023-11-15 22:50:41 -08:00
committed by GitHub
parent 660a7fcfa4
commit 7076fa1c9f
36 changed files with 2159 additions and 2508 deletions

View File

@@ -31,15 +31,17 @@ from transformers import GPTBigCodeConfig
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention import PagedAttention
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearMethodBase,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.weight_utils import (
convert_pyslice_to_tensor, hf_model_weights_iterator,
load_padded_tensor_parallel_vocab, load_tensor_parallel_weights)
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding,
ColumnParallelLinear,
RowParallelLinear)
get_tensor_model_parallel_world_size)
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor]
@@ -47,7 +49,11 @@ KVCache = Tuple[torch.Tensor, torch.Tensor]
class GPTBigCodeAttention(nn.Module):
def __init__(self, config: GPTBigCodeConfig):
def __init__(
self,
config: GPTBigCodeConfig,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__()
self.hidden_size = config.hidden_size
total_num_heads = config.num_attention_heads
@@ -61,32 +67,26 @@ class GPTBigCodeAttention(nn.Module):
self.multi_query = config.multi_query
if self.multi_query:
total_num_kv_heads = 1
self.num_kv_heads = 1
self.kv_dim = self.head_dim
self.c_attn_q = ColumnParallelLinear(
self.hidden_size,
self.hidden_size,
bias=True,
gather_output=False,
)
self.c_attn_kv = nn.Linear(self.hidden_size,
2 * self.kv_dim,
bias=True)
else:
total_num_kv_heads = total_num_heads
self.num_kv_heads = self.num_heads
self.kv_dim = self.num_kv_heads * self.head_dim
self.c_attn = ColumnParallelLinear(
self.hidden_size,
self.hidden_size + 2 * self.kv_dim,
bias=True,
gather_output=False,
)
self.kv_dim = self.head_dim * self.num_kv_heads
self.c_attn = QKVParallelLinear(
self.hidden_size,
self.head_dim,
total_num_heads,
total_num_kv_heads,
bias=True,
linear_method=linear_method,
)
self.c_proj = RowParallelLinear(
self.hidden_size,
self.hidden_size,
bias=True,
input_is_parallel=True,
linear_method=linear_method,
)
self.attn = PagedAttention(self.num_heads,
self.head_dim,
@@ -100,17 +100,14 @@ class GPTBigCodeAttention(nn.Module):
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
if self.multi_query:
q, _ = self.c_attn_q(hidden_states)
kv = self.c_attn_kv(hidden_states)
k, v = kv.split([self.kv_dim, self.kv_dim], dim=-1)
else:
qkv, _ = self.c_attn(hidden_states)
q, k, v = qkv.split([
qkv, _ = self.c_attn(hidden_states)
q, k, v = qkv.split(
[
self.hidden_size // self.tensor_model_parallel_world_size,
self.kv_dim, self.kv_dim
],
dim=-1)
dim=-1,
)
key_cache, value_cache = kv_cache
attn_output = self.attn(q, k, v, key_cache, value_cache,
input_metadata, cache_event)
@@ -124,6 +121,7 @@ class GPTBigMLP(nn.Module):
self,
intermediate_size: int,
config: GPTBigCodeConfig,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__()
hidden_size = config.hidden_size
@@ -131,13 +129,13 @@ class GPTBigMLP(nn.Module):
hidden_size,
intermediate_size,
bias=True,
gather_output=False,
linear_method=linear_method,
)
self.c_proj = RowParallelLinear(
intermediate_size,
hidden_size,
bias=True,
input_is_parallel=True,
linear_method=linear_method,
)
self.act = get_act_fn(config.activation_function)
@@ -150,16 +148,20 @@ class GPTBigMLP(nn.Module):
class GPTBigCodeBlock(nn.Module):
def __init__(self, config: GPTBigCodeConfig):
def __init__(
self,
config: GPTBigCodeConfig,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__()
hidden_size = config.hidden_size
inner_dim = (config.n_inner if config.n_inner is not None else 4 *
hidden_size)
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.attn = GPTBigCodeAttention(config)
self.attn = GPTBigCodeAttention(config, linear_method)
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.mlp = GPTBigMLP(inner_dim, config)
self.mlp = GPTBigMLP(inner_dim, config, linear_method)
def forward(
self,
@@ -189,23 +191,23 @@ class GPTBigCodeBlock(nn.Module):
class GPTBigCodeModel(nn.Module):
def __init__(self, config: GPTBigCodeConfig):
def __init__(
self,
config: GPTBigCodeConfig,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__()
self.config = config
assert not config.add_cross_attention
self.embed_dim = config.hidden_size
# Optimization: While the vocab size of GPT-2 is 50257, we extend it
# to 50304 in order to make it divisible by 64.
# This improves performance since GPUs are faster if the dimension
# is divisible by 64. In addition, it allows us to shard the embedding
# layer across 2, 4, 8, or more GPUs.
vocab_size = ((config.vocab_size + 63) // 64) * 64
self.wte = VocabParallelEmbedding(vocab_size, self.embed_dim)
self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim)
self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
self.h = nn.ModuleList(
[GPTBigCodeBlock(config) for _ in range(config.num_hidden_layers)])
self.h = nn.ModuleList([
GPTBigCodeBlock(config, linear_method)
for _ in range(config.num_hidden_layers)
])
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
def forward(
@@ -235,12 +237,15 @@ class GPTBigCodeModel(nn.Module):
class GPTBigCodeForCausalLM(nn.Module):
def __init__(self, config: GPTBigCodeConfig):
def __init__(
self,
config: GPTBigCodeConfig,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__()
self.config = config
self.transformer = GPTBigCodeModel(config)
# TODO(zhuohan): create a new weight after implementing pipeline
# parallelism
self.linear_method = linear_method
self.transformer = GPTBigCodeModel(config, linear_method)
self.lm_head_weight = self.transformer.wte.weight
self.sampler = Sampler(config.vocab_size)
@@ -258,89 +263,21 @@ class GPTBigCodeForCausalLM(nn.Module):
input_metadata)
return next_tokens
_column_parallel_weights = ["c_fc.weight", "c_fc.bias"]
_row_parallel_weights = ["c_proj.weight"]
def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto",
revision: Optional[str] = None):
tensor_model_parallel_world_size = (
get_tensor_model_parallel_world_size())
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
state_dict = self.state_dict()
params_dict = dict(self.named_parameters(remove_duplicate=False))
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format, revision):
if "lm_head.weight" in name:
# GPT-2 ties the weights of the embedding layer and the final
# linear layer.
continue
if ".attn.bias" in name:
# Skip attention mask.
# NOTE: "c_attn.bias" should not be skipped.
continue
if not name.startswith("transformer."):
name = "transformer." + name
# For the fused QKV linear layer, manually shard the weights.
if "c_attn" in name:
# GPT-2's fused QKV has the shape of
# [3 * num_heads * head_size, hidden_size].
# When tensor parallelism is used, we shard the weights along
# the head dimension.
total_num_heads = self.config.num_attention_heads
total_num_kv_heads = (1 if self.config.multi_query else
total_num_heads)
hidden_size = self.config.hidden_size
head_size = hidden_size // total_num_heads
total_kv_size = head_size * total_num_kv_heads
num_heads = total_num_heads // tensor_model_parallel_world_size
head_start = tensor_model_parallel_rank * num_heads
head_end = (tensor_model_parallel_rank + 1) * num_heads
loaded_weight = convert_pyslice_to_tensor(loaded_weight)
wq, wk, wv = torch.split(
loaded_weight, [hidden_size, total_kv_size, total_kv_size],
dim=0)
wq = wq[head_size * head_start:head_size * head_end]
if not self.config.multi_query:
# Split the heads when using normal multi-head attention
wk = wk[head_size * head_start:head_size * head_end]
wv = wv[head_size * head_start:head_size * head_end]
loaded_weight = torch.cat([wq, wk, wv], dim=0)
else:
# For multi-query attention, we split the query
# but replicate the key and value.
loaded_weight_q = wq
loaded_weight_kv = torch.cat([wk, wv], dim=0)
q_weight_name = name.replace("c_attn", "c_attn_q")
kv_weight_name = name.replace("c_attn", "c_attn_kv")
load_tensor_parallel_weights(state_dict[q_weight_name],
loaded_weight_q,
q_weight_name,
self._column_parallel_weights,
self._row_parallel_weights,
tensor_model_parallel_rank)
load_tensor_parallel_weights(state_dict[kv_weight_name],
loaded_weight_kv,
kv_weight_name,
self._column_parallel_weights,
self._row_parallel_weights,
tensor_model_parallel_rank)
continue
param = state_dict[name]
if name == "transformer.wte.weight":
load_padded_tensor_parallel_vocab(param, loaded_weight,
tensor_model_parallel_rank)
continue
load_tensor_parallel_weights(param, loaded_weight, name,
self._column_parallel_weights,
self._row_parallel_weights,
tensor_model_parallel_rank)
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)