TP/quantization/weight loading refactor part 1 - Simplify parallel linear logic (#1181)

This commit is contained in:
Zhuohan Li
2023-10-02 15:36:09 -07:00
committed by GitHub
parent 84e4e37d14
commit ba0bfd40e2
42 changed files with 819 additions and 1547 deletions

View File

@@ -35,8 +35,9 @@ from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
load_tensor_parallel_weights)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import (
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding,
ColumnParallelLinear,
RowParallelLinear)
from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor]
@@ -73,16 +74,18 @@ class OPTAttention(nn.Module):
self.head_dim = embed_dim // total_num_heads
self.scaling = self.head_dim**-0.5
self.qkv_proj = ColumnParallelLinear(embed_dim,
3 * embed_dim,
bias=bias,
gather_output=False,
perform_initialization=False)
self.out_proj = RowParallelLinear(embed_dim,
embed_dim,
bias=bias,
input_is_parallel=True,
perform_initialization=False)
self.qkv_proj = ColumnParallelLinear(
embed_dim,
3 * embed_dim,
bias=bias,
gather_output=False,
)
self.out_proj = RowParallelLinear(
embed_dim,
embed_dim,
bias=bias,
input_is_parallel=True,
)
self.attn = PagedAttention(self.num_heads,
self.head_dim,
scale=self.scaling)
@@ -120,16 +123,18 @@ class OPTDecoderLayer(nn.Module):
self.self_attn_layer_norm = nn.LayerNorm(
self.embed_dim,
elementwise_affine=config.layer_norm_elementwise_affine)
self.fc1 = ColumnParallelLinear(self.embed_dim,
config.ffn_dim,
bias=config.enable_bias,
gather_output=False,
perform_initialization=False)
self.fc2 = RowParallelLinear(config.ffn_dim,
self.embed_dim,
bias=config.enable_bias,
input_is_parallel=True,
perform_initialization=False)
self.fc1 = ColumnParallelLinear(
self.embed_dim,
config.ffn_dim,
bias=config.enable_bias,
gather_output=False,
)
self.fc2 = RowParallelLinear(
config.ffn_dim,
self.embed_dim,
bias=config.enable_bias,
input_is_parallel=True,
)
self.final_layer_norm = nn.LayerNorm(
self.embed_dim,
elementwise_affine=config.layer_norm_elementwise_affine)
@@ -182,7 +187,7 @@ class OPTDecoder(nn.Module):
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.word_embed_proj_dim,
perform_initialization=False)
)
# Positional embeddings are replicated (not sharded).
self.embed_positions = OPTLearnedPositionalEmbedding(
config.max_position_embeddings, config.hidden_size)