TP/quantization/weight loading refactor part 2 - Refactor quantized linear logic and extend quantization support to all models (#1622)
Refactor the tensor parallelism, quantization, and weight-loading codes. Summary of the new features enabled by this PR: - **All models** are able to be quantized with AWQ and SqueezeLLM, and [soon GPTQ](https://github.com/vllm-project/vllm/pull/1580). - Model loading code became much simpler. - Support model parallelism for all MQA/GQA models when the number of key/value heads is smaller than the tensor parallel size.
This commit is contained in:
@@ -30,18 +30,20 @@ from torch import nn
|
||||
|
||||
from vllm.model_executor.input_metadata import InputMetadata
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.attention import (PagedAttentionWithRoPE,
|
||||
PagedAttentionWithALiBi)
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (LinearMethodBase,
|
||||
MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.weight_utils import (
|
||||
convert_pyslice_to_tensor, hf_model_weights_iterator,
|
||||
load_padded_tensor_parallel_vocab, load_tensor_parallel_weights)
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding, ParallelLMHead)
|
||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||
from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding,
|
||||
ColumnParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.weight_utils import (default_weight_loader,
|
||||
hf_model_weights_iterator)
|
||||
from vllm.sequence import SamplerOutput
|
||||
from vllm.transformers_utils.configs.baichuan import BaiChuanConfig
|
||||
|
||||
@@ -80,20 +82,17 @@ class BaiChuanMLP(nn.Module):
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
hidden_act: str,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.gate_up_proj = ColumnParallelLinear(
|
||||
hidden_size,
|
||||
2 * intermediate_size,
|
||||
self.gate_up_proj = MergedColumnParallelLinear(
|
||||
hidden_size, [intermediate_size] * 2,
|
||||
bias=False,
|
||||
gather_output=False,
|
||||
)
|
||||
self.down_proj = RowParallelLinear(
|
||||
intermediate_size,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
input_is_parallel=True,
|
||||
)
|
||||
linear_method=linear_method)
|
||||
self.down_proj = RowParallelLinear(intermediate_size,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
linear_method=linear_method)
|
||||
if hidden_act != "silu":
|
||||
raise ValueError(f"Unsupported activation: {hidden_act}. "
|
||||
"Only silu is supported for now.")
|
||||
@@ -116,6 +115,7 @@ class BaiChuanAttention(nn.Module):
|
||||
position_embedding: str,
|
||||
rope_theta: float = 10000,
|
||||
max_position_embeddings: int = 8192,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
@@ -131,17 +131,19 @@ class BaiChuanAttention(nn.Module):
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
|
||||
# pylint: disable=invalid-name
|
||||
self.W_pack = ColumnParallelLinear(
|
||||
self.W_pack = QKVParallelLinear(
|
||||
hidden_size,
|
||||
3 * hidden_size,
|
||||
self.head_dim,
|
||||
self.total_num_heads,
|
||||
self.total_num_heads,
|
||||
bias=False,
|
||||
gather_output=False,
|
||||
linear_method=linear_method,
|
||||
)
|
||||
self.o_proj = RowParallelLinear(
|
||||
self.total_num_heads * self.head_dim,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
input_is_parallel=True,
|
||||
linear_method=linear_method,
|
||||
)
|
||||
# Create the alibi slopes and slice them.
|
||||
if self.postion_embedding == "ALIBI":
|
||||
@@ -188,7 +190,10 @@ class BaiChuanAttention(nn.Module):
|
||||
|
||||
class BaiChuanDecoderLayer(nn.Module):
|
||||
|
||||
def __init__(self, config: BaiChuanConfig, position_embedding: str):
|
||||
def __init__(self,
|
||||
config: BaiChuanConfig,
|
||||
position_embedding: str,
|
||||
linear_method: Optional[LinearMethodBase] = None):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
rope_theta = getattr(config, "rope_theta", 10000)
|
||||
@@ -200,11 +205,13 @@ class BaiChuanDecoderLayer(nn.Module):
|
||||
position_embedding=position_embedding,
|
||||
rope_theta=rope_theta,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
linear_method=linear_method,
|
||||
)
|
||||
self.mlp = BaiChuanMLP(
|
||||
hidden_size=self.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
linear_method=linear_method,
|
||||
)
|
||||
self.input_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
@@ -241,7 +248,10 @@ class BaiChuanDecoderLayer(nn.Module):
|
||||
|
||||
class BaiChuanModel(nn.Module):
|
||||
|
||||
def __init__(self, config: BaiChuanConfig, position_embedding: str):
|
||||
def __init__(self,
|
||||
config: BaiChuanConfig,
|
||||
position_embedding: str,
|
||||
linear_method: Optional[LinearMethodBase] = None):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.padding_idx = config.pad_token_id
|
||||
@@ -252,7 +262,7 @@ class BaiChuanModel(nn.Module):
|
||||
config.hidden_size,
|
||||
)
|
||||
self.layers = nn.ModuleList([
|
||||
BaiChuanDecoderLayer(config, position_embedding)
|
||||
BaiChuanDecoderLayer(config, position_embedding, linear_method)
|
||||
for _ in range(config.num_hidden_layers)
|
||||
])
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
@@ -285,16 +295,15 @@ class BaiChuanModel(nn.Module):
|
||||
|
||||
class BaiChuanBaseForCausalLM(nn.Module):
|
||||
|
||||
def __init__(self, config, position_embedding: str):
|
||||
def __init__(self,
|
||||
config,
|
||||
position_embedding: str,
|
||||
linear_method: Optional[LinearMethodBase] = None):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.model = BaiChuanModel(config, position_embedding)
|
||||
self.lm_head = ColumnParallelLinear(
|
||||
config.hidden_size,
|
||||
config.vocab_size,
|
||||
bias=False,
|
||||
gather_output=False,
|
||||
)
|
||||
self.linear_method = linear_method
|
||||
self.model = BaiChuanModel(config, position_embedding, linear_method)
|
||||
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
||||
self.sampler = Sampler(config.vocab_size)
|
||||
|
||||
def forward(
|
||||
@@ -311,79 +320,46 @@ class BaiChuanBaseForCausalLM(nn.Module):
|
||||
input_metadata)
|
||||
return next_tokens
|
||||
|
||||
_column_parallel_weights = []
|
||||
_row_parallel_weights = ["o_proj.weight", "down_proj.weight"]
|
||||
|
||||
def load_weights(self,
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
load_format: str = "auto",
|
||||
revision: Optional[str] = None):
|
||||
tp_world_size = get_tensor_model_parallel_world_size()
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
state_dict = self.state_dict()
|
||||
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("gate_up_proj", "gate_proj", 0),
|
||||
("gate_up_proj", "up_proj", 1),
|
||||
]
|
||||
params_dict = dict(self.named_parameters())
|
||||
for name, loaded_weight in hf_model_weights_iterator(
|
||||
model_name_or_path, cache_dir, load_format, revision):
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
|
||||
loaded_weight = convert_pyslice_to_tensor(loaded_weight)
|
||||
|
||||
if "W_pack" in name:
|
||||
total_num_heads = self.config.num_attention_heads
|
||||
hidden_size = self.config.hidden_size
|
||||
head_size = hidden_size // total_num_heads
|
||||
num_heads = total_num_heads // tp_world_size
|
||||
head_start = tp_rank * num_heads
|
||||
head_end = (tp_rank + 1) * num_heads
|
||||
|
||||
loaded_weight = loaded_weight.view(3, total_num_heads,
|
||||
head_size, hidden_size)
|
||||
loaded_weight = loaded_weight[:, head_start:head_end, :, :]
|
||||
loaded_weight = loaded_weight.reshape(-1, hidden_size)
|
||||
|
||||
is_gate_up_weight = False
|
||||
for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]):
|
||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
param = state_dict[name.replace(weight_name, "gate_up_proj")]
|
||||
shard_size = param.shape[0] // 2
|
||||
loaded_weight = loaded_weight[shard_size * tp_rank:shard_size *
|
||||
(tp_rank + 1)]
|
||||
param_slice = param.data[shard_size * stride_id:shard_size *
|
||||
(stride_id + 1)]
|
||||
assert param_slice.shape == loaded_weight.shape
|
||||
param_slice.copy_(loaded_weight)
|
||||
is_gate_up_weight = True
|
||||
param = params_dict[name.replace(weight_name, param_name)]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
if is_gate_up_weight:
|
||||
continue
|
||||
|
||||
param = state_dict[name]
|
||||
|
||||
if "embed_tokens" in name or "lm_head" in name:
|
||||
load_padded_tensor_parallel_vocab(param, loaded_weight,
|
||||
tp_rank)
|
||||
continue
|
||||
|
||||
load_tensor_parallel_weights(
|
||||
param,
|
||||
loaded_weight,
|
||||
name,
|
||||
self._column_parallel_weights,
|
||||
self._row_parallel_weights,
|
||||
tp_rank,
|
||||
)
|
||||
else:
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
|
||||
|
||||
class BaichuanForCausalLM(BaiChuanBaseForCausalLM): # baichuan 13b
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config, "ALIBI")
|
||||
def __init__(self,
|
||||
config,
|
||||
linear_method: Optional[LinearMethodBase] = None):
|
||||
super().__init__(config, "ALIBI", linear_method)
|
||||
|
||||
|
||||
class BaiChuanForCausalLM(BaiChuanBaseForCausalLM): # baichuan 7b
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config, "ROPE")
|
||||
def __init__(self,
|
||||
config,
|
||||
linear_method: Optional[LinearMethodBase] = None):
|
||||
super().__init__(config, "ROPE", linear_method)
|
||||
|
||||
Reference in New Issue
Block a user