TP/quantization/weight loading refactor part 2 - Refactor quantized linear logic and extend quantization support to all models (#1622)
Refactor the tensor parallelism, quantization, and weight-loading codes. Summary of the new features enabled by this PR: - **All models** are able to be quantized with AWQ and SqueezeLLM, and [soon GPTQ](https://github.com/vllm-project/vllm/pull/1580). - Model loading code became much simpler. - Support model parallelism for all MQA/GQA models when the number of key/value heads is smaller than the tensor parallel size.
This commit is contained in:
@@ -30,14 +30,18 @@ from transformers import OPTConfig
|
||||
from vllm.model_executor.input_metadata import InputMetadata
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.attention import PagedAttention
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
LinearMethodBase,
|
||||
QKVParallelLinear,
|
||||
ReplicatedLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
|
||||
load_tensor_parallel_weights)
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding)
|
||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||
from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding,
|
||||
ColumnParallelLinear,
|
||||
RowParallelLinear)
|
||||
get_tensor_model_parallel_world_size)
|
||||
from vllm.model_executor.weight_utils import (default_weight_loader,
|
||||
hf_model_weights_iterator)
|
||||
from vllm.sequence import SamplerOutput
|
||||
|
||||
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||
@@ -63,6 +67,7 @@ class OPTAttention(nn.Module):
|
||||
embed_dim: int,
|
||||
num_heads: int,
|
||||
bias: bool = True,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.embed_dim = embed_dim
|
||||
@@ -74,17 +79,18 @@ class OPTAttention(nn.Module):
|
||||
self.head_dim = embed_dim // total_num_heads
|
||||
self.scaling = self.head_dim**-0.5
|
||||
|
||||
self.qkv_proj = ColumnParallelLinear(
|
||||
self.qkv_proj = QKVParallelLinear(
|
||||
embed_dim,
|
||||
3 * embed_dim,
|
||||
self.head_dim,
|
||||
total_num_heads,
|
||||
bias=bias,
|
||||
gather_output=False,
|
||||
linear_method=linear_method,
|
||||
)
|
||||
self.out_proj = RowParallelLinear(
|
||||
embed_dim,
|
||||
embed_dim,
|
||||
bias=bias,
|
||||
input_is_parallel=True,
|
||||
linear_method=linear_method,
|
||||
)
|
||||
self.attn = PagedAttention(self.num_heads,
|
||||
self.head_dim,
|
||||
@@ -108,7 +114,11 @@ class OPTAttention(nn.Module):
|
||||
|
||||
class OPTDecoderLayer(nn.Module):
|
||||
|
||||
def __init__(self, config: OPTConfig):
|
||||
def __init__(
|
||||
self,
|
||||
config: OPTConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.embed_dim = config.hidden_size
|
||||
@@ -116,6 +126,7 @@ class OPTDecoderLayer(nn.Module):
|
||||
embed_dim=self.embed_dim,
|
||||
num_heads=config.num_attention_heads,
|
||||
bias=config.enable_bias,
|
||||
linear_method=linear_method,
|
||||
)
|
||||
self.do_layer_norm_before = config.do_layer_norm_before
|
||||
self.activation_fn = get_act_fn(config.activation_function)
|
||||
@@ -127,13 +138,13 @@ class OPTDecoderLayer(nn.Module):
|
||||
self.embed_dim,
|
||||
config.ffn_dim,
|
||||
bias=config.enable_bias,
|
||||
gather_output=False,
|
||||
linear_method=linear_method,
|
||||
)
|
||||
self.fc2 = RowParallelLinear(
|
||||
config.ffn_dim,
|
||||
self.embed_dim,
|
||||
bias=config.enable_bias,
|
||||
input_is_parallel=True,
|
||||
linear_method=linear_method,
|
||||
)
|
||||
self.final_layer_norm = nn.LayerNorm(
|
||||
self.embed_dim,
|
||||
@@ -177,7 +188,11 @@ class OPTDecoderLayer(nn.Module):
|
||||
|
||||
class OPTDecoder(nn.Module):
|
||||
|
||||
def __init__(self, config: OPTConfig):
|
||||
def __init__(
|
||||
self,
|
||||
config: OPTConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.padding_idx = config.pad_token_id
|
||||
@@ -194,16 +209,18 @@ class OPTDecoder(nn.Module):
|
||||
|
||||
# Project out & in will be replicated if they exist.
|
||||
if config.word_embed_proj_dim != config.hidden_size:
|
||||
self.project_out = nn.Linear(config.hidden_size,
|
||||
config.word_embed_proj_dim,
|
||||
bias=False)
|
||||
self.project_out = ReplicatedLinear(config.hidden_size,
|
||||
config.word_embed_proj_dim,
|
||||
bias=False,
|
||||
linear_method=linear_method)
|
||||
else:
|
||||
self.project_out = None
|
||||
|
||||
if config.word_embed_proj_dim != config.hidden_size:
|
||||
self.project_in = nn.Linear(config.word_embed_proj_dim,
|
||||
config.hidden_size,
|
||||
bias=False)
|
||||
self.project_in = ReplicatedLinear(config.word_embed_proj_dim,
|
||||
config.hidden_size,
|
||||
bias=False,
|
||||
linear_method=linear_method)
|
||||
else:
|
||||
self.project_in = None
|
||||
|
||||
@@ -218,8 +235,10 @@ class OPTDecoder(nn.Module):
|
||||
else:
|
||||
self.final_layer_norm = None
|
||||
|
||||
self.layers = nn.ModuleList(
|
||||
[OPTDecoderLayer(config) for _ in range(config.num_hidden_layers)])
|
||||
self.layers = nn.ModuleList([
|
||||
OPTDecoderLayer(config, linear_method)
|
||||
for _ in range(config.num_hidden_layers)
|
||||
])
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -253,9 +272,13 @@ class OPTDecoder(nn.Module):
|
||||
|
||||
class OPTModel(nn.Module):
|
||||
|
||||
def __init__(self, config: OPTConfig):
|
||||
def __init__(
|
||||
self,
|
||||
config: OPTConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.decoder = OPTDecoder(config)
|
||||
self.decoder = OPTDecoder(config, linear_method)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -271,12 +294,15 @@ class OPTModel(nn.Module):
|
||||
|
||||
class OPTForCausalLM(nn.Module):
|
||||
|
||||
def __init__(self, config):
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.model = OPTModel(config)
|
||||
# TODO(zhuohan): create a new weight after implementing pipeline
|
||||
# parallelism
|
||||
self.linear_method = linear_method
|
||||
self.model = OPTModel(config, linear_method)
|
||||
self.lm_head_weight = self.model.decoder.embed_tokens.weight
|
||||
self.sampler = Sampler(config.vocab_size)
|
||||
|
||||
@@ -294,48 +320,31 @@ class OPTForCausalLM(nn.Module):
|
||||
input_metadata)
|
||||
return next_tokens
|
||||
|
||||
_column_parallel_weights = [
|
||||
"embed_tokens.weight", "fc1.weight", "fc1.bias"
|
||||
]
|
||||
_row_parallel_weights = ["out_proj.weight", "fc2.weight"]
|
||||
|
||||
def load_weights(self,
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
load_format: str = "auto",
|
||||
revision: Optional[str] = None):
|
||||
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
|
||||
state_dict = self.state_dict()
|
||||
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
("qkv_proj", "k_proj", "k"),
|
||||
("qkv_proj", "v_proj", "v"),
|
||||
]
|
||||
params_dict = dict(self.named_parameters(remove_duplicate=False))
|
||||
for name, loaded_weight in hf_model_weights_iterator(
|
||||
model_name_or_path, cache_dir, load_format, revision):
|
||||
if "lm_head.weight" in name:
|
||||
continue
|
||||
|
||||
if name.startswith("decoder."):
|
||||
name = "model." + name
|
||||
|
||||
is_attention_weight = False
|
||||
for stride_id, att_weight_name in enumerate(
|
||||
["q_proj", "k_proj", "v_proj"]):
|
||||
if att_weight_name not in name:
|
||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
param = state_dict[name.replace(att_weight_name, "qkv_proj")]
|
||||
shard_size = param.shape[0] // 3
|
||||
loaded_weight = loaded_weight[
|
||||
shard_size * tensor_model_parallel_rank:shard_size *
|
||||
(tensor_model_parallel_rank + 1)]
|
||||
param_slice = param.data[shard_size * stride_id:shard_size *
|
||||
(stride_id + 1)]
|
||||
assert param_slice.shape == loaded_weight.shape
|
||||
param_slice.copy_(loaded_weight)
|
||||
is_attention_weight = True
|
||||
param = params_dict[name.replace(weight_name, param_name)]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
if is_attention_weight:
|
||||
continue
|
||||
|
||||
param = state_dict[name]
|
||||
load_tensor_parallel_weights(param, loaded_weight, name,
|
||||
self._column_parallel_weights,
|
||||
self._row_parallel_weights,
|
||||
tensor_model_parallel_rank)
|
||||
else:
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
|
||||
Reference in New Issue
Block a user