TP/quantization/weight loading refactor part 2 - Refactor quantized linear logic and extend quantization support to all models (#1622)

Refactor the tensor parallelism, quantization, and weight-loading codes.

Summary of the new features enabled by this PR:
- **All models** are able to be quantized with AWQ and SqueezeLLM, and [soon GPTQ](https://github.com/vllm-project/vllm/pull/1580).
- Model loading code became much simpler.
- Support model parallelism for all MQA/GQA models when the number of key/value heads is smaller than the tensor parallel size.
This commit is contained in:
Zhuohan Li
2023-11-15 22:50:41 -08:00
committed by GitHub
parent 660a7fcfa4
commit 7076fa1c9f
36 changed files with 2159 additions and 2508 deletions

View File

@@ -15,24 +15,19 @@ from torch import nn
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.weight_utils import (
convert_pyslice_to_tensor,
hf_model_weights_iterator,
load_padded_tensor_parallel_vocab,
load_tensor_parallel_weights,
)
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from vllm.model_executor.parallel_utils.layers import (
VocabParallelEmbedding,
ColumnParallelLinear,
RowParallelLinear,
)
get_tensor_model_parallel_world_size)
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs.qwen import QWenConfig
@@ -46,20 +41,17 @@ class QWenMLP(nn.Module):
hidden_size: int,
intermediate_size: int,
hidden_act: str = "silu",
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__()
self.gate_up_proj = ColumnParallelLinear(
hidden_size,
2 * intermediate_size,
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2,
bias=False,
gather_output=False,
)
self.c_proj = RowParallelLinear(
intermediate_size,
hidden_size,
bias=False,
input_is_parallel=True,
)
linear_method=linear_method)
self.c_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=False,
linear_method=linear_method)
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
@@ -74,12 +66,15 @@ class QWenMLP(nn.Module):
class QWenAttention(nn.Module):
def __init__(self,
hidden_size: int,
num_heads: int,
max_position_embeddings: int,
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None):
def __init__(
self,
hidden_size: int,
num_heads: int,
max_position_embeddings: int,
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__()
self.hidden_size = hidden_size
tensor_model_parallel_world_size = get_tensor_model_parallel_world_size(
@@ -90,18 +85,18 @@ class QWenAttention(nn.Module):
tensor_model_parallel_world_size)
self.head_dim = hidden_size // self.total_num_heads
# pylint: disable=invalid-name
self.c_attn = ColumnParallelLinear(
self.c_attn = QKVParallelLinear(
hidden_size,
3 * hidden_size,
self.head_dim,
self.total_num_heads,
bias=True,
gather_output=False,
linear_method=linear_method,
)
self.c_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
input_is_parallel=True,
linear_method=linear_method,
)
self.scaling = self.head_dim**-0.5
self.attn = PagedAttentionWithRoPE(
@@ -134,7 +129,11 @@ class QWenAttention(nn.Module):
class QWenBlock(nn.Module):
def __init__(self, config: QWenConfig):
def __init__(
self,
config: QWenConfig,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__()
self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
@@ -144,11 +143,14 @@ class QWenBlock(nn.Module):
config.num_attention_heads,
config.max_position_embeddings,
rope_theta=rope_theta,
rope_scaling=rope_scaling)
rope_scaling=rope_scaling,
linear_method=linear_method)
self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
self.mlp = QWenMLP(config.hidden_size, config.intermediate_size // 2)
self.mlp = QWenMLP(config.hidden_size,
config.intermediate_size // 2,
linear_method=linear_method)
def forward(
self,
@@ -180,18 +182,23 @@ class QWenBlock(nn.Module):
class QWenModel(nn.Module):
def __init__(self, config: QWenConfig):
def __init__(
self,
config: QWenConfig,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__()
self.config = config
self.vocab_size = config.vocab_size
vocab_size = ((config.vocab_size + 63) // 64) * 64
self.wte = VocabParallelEmbedding(
vocab_size,
config.vocab_size,
config.hidden_size,
)
self.h = nn.ModuleList(
[QWenBlock(config) for _ in range(config.num_hidden_layers)])
self.h = nn.ModuleList([
QWenBlock(config, linear_method)
for _ in range(config.num_hidden_layers)
])
self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
def forward(
@@ -222,17 +229,16 @@ class QWenModel(nn.Module):
class QWenLMHeadModel(nn.Module):
def __init__(self, config: QWenConfig):
def __init__(
self,
config: QWenConfig,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__()
self.config = config
self.transformer = QWenModel(config)
vocab_size = ((config.vocab_size + 63) // 64) * 64
self.lm_head = ColumnParallelLinear(
config.hidden_size,
vocab_size,
bias=False,
gather_output=False,
)
self.linear_method = linear_method
self.transformer = QWenModel(config, linear_method)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.sampler = Sampler(config.vocab_size)
def forward(
@@ -249,75 +255,30 @@ class QWenLMHeadModel(nn.Module):
input_metadata)
return next_tokens
_column_parallel_weights = []
_row_parallel_weights = ["c_proj.weight"]
def load_weights(
self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto",
revision: Optional[str] = None,
):
tp_world_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()
state_dict = self.state_dict()
def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto",
revision: Optional[str] = None):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("gate_up_proj", "w2", 0),
("gate_up_proj", "w1", 1),
]
params_dict = dict(self.named_parameters())
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format, revision):
if "rotary_emb.inv_freq" in name:
continue
loaded_weight = convert_pyslice_to_tensor(loaded_weight)
if "c_attn" in name:
total_num_heads = self.config.num_attention_heads
hidden_size = self.config.hidden_size
head_size = hidden_size // total_num_heads
num_heads = total_num_heads // tp_world_size
head_start = tp_rank * num_heads
head_end = (tp_rank + 1) * num_heads
if "weight" in name:
loaded_weight = loaded_weight.view(3, total_num_heads,
head_size, hidden_size)
loaded_weight = loaded_weight[:, head_start:head_end, :, :]
loaded_weight = loaded_weight.reshape(-1, hidden_size)
elif "bias" in name:
loaded_weight = loaded_weight.view(3, total_num_heads,
head_size)
loaded_weight = loaded_weight[:, head_start:head_end, :]
loaded_weight = loaded_weight.reshape(-1)
is_gate_up_weight = False
for stride_id, weight_name in enumerate(["w2", "w1"]):
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
param = state_dict[name.replace(weight_name, "gate_up_proj")]
shard_size = param.shape[0] // 2
loaded_weight = loaded_weight[shard_size * tp_rank:shard_size *
(tp_rank + 1)]
param_slice = param.data[shard_size * stride_id:shard_size *
(stride_id + 1)]
assert param_slice.shape == loaded_weight.shape
param_slice.copy_(loaded_weight)
is_gate_up_weight = True
param = params_dict[name.replace(weight_name, param_name)]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
if is_gate_up_weight:
continue
param = state_dict[name]
if "wte" in name or "lm_head" in name:
load_padded_tensor_parallel_vocab(param, loaded_weight,
tp_rank)
continue
load_tensor_parallel_weights(
param,
loaded_weight,
name,
self._column_parallel_weights,
self._row_parallel_weights,
tp_rank,
)
else:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)