TP/quantization/weight loading refactor part 1 - Simplify parallel linear logic (#1181)
@@ -35,8 +35,9 @@ from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
                                                load_tensor_parallel_weights)
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
-from vllm.model_executor.parallel_utils.tensor_parallel import (
-    VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
+from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding,
+                                                       ColumnParallelLinear,
+                                                       RowParallelLinear)
 from vllm.sequence import SamplerOutput
 
 KVCache = Tuple[torch.Tensor, torch.Tensor]
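
For orientation: the two `parallel_state` helpers kept by this hunk report where the current process sits in the tensor-parallel group, and the layers now imported from `parallel_utils.layers` use them to size their weight shards. A minimal, self-contained sketch of that arithmetic (the `divide` helper and all concrete numbers are assumptions for illustration, not vLLM's actual code):

# Sketch: how the TP rank and world size translate into weight-shard sizes.
# These two values stand in for get_tensor_model_parallel_rank() and
# get_tensor_model_parallel_world_size(); the numbers are assumed.
world_size = 4
rank = 1

def divide(numerator: int, denominator: int) -> int:
    # Hypothetical helper: integer division that insists on an even split.
    assert numerator % denominator == 0, "size must be divisible by TP degree"
    return numerator // denominator

embed_dim = 768
# ColumnParallelLinear(embed_dim, 3 * embed_dim) gives each rank a weight
# shard of shape (3 * embed_dim // world_size, embed_dim):
out_per_rank = divide(3 * embed_dim, world_size)       # 576 output features
start = rank * out_per_rank
print(f"rank {rank} owns output features [{start}, {start + out_per_rank})")
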
@@ -73,16 +74,18 @@ class OPTAttention(nn.Module):
         self.head_dim = embed_dim // total_num_heads
         self.scaling = self.head_dim**-0.5
 
-        self.qkv_proj = ColumnParallelLinear(embed_dim,
-                                             3 * embed_dim,
-                                             bias=bias,
-                                             gather_output=False,
-                                             perform_initialization=False)
-        self.out_proj = RowParallelLinear(embed_dim,
-                                          embed_dim,
-                                          bias=bias,
-                                          input_is_parallel=True,
-                                          perform_initialization=False)
+        self.qkv_proj = ColumnParallelLinear(
+            embed_dim,
+            3 * embed_dim,
+            bias=bias,
+            gather_output=False,
+        )
+        self.out_proj = RowParallelLinear(
+            embed_dim,
+            embed_dim,
+            bias=bias,
+            input_is_parallel=True,
+        )
         self.attn = PagedAttention(self.num_heads,
                                    self.head_dim,
                                    scale=self.scaling)
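
The `gather_output=False` / `input_is_parallel=True` pairing above is what lets the column-parallel `qkv_proj` feed the row-parallel `out_proj` without gathering activations in between: each rank keeps only its slice of the intermediate features, and a single all-reduce happens after `out_proj`. A single-process sketch of that algebra, with the TP group simulated by a loop and the attention step between the two projections elided (all sizes are illustrative assumptions):

import torch

torch.manual_seed(0)
d, tp = 8, 2                        # model dim and simulated TP degree
x = torch.randn(3, d)               # three token activations

# Full (unsharded) reference weights for a projection pair.
w_col = torch.randn(3 * d, d)       # column-parallel weight, (out, in)
w_row = torch.randn(d, 3 * d)       # row-parallel weight, (out, in)
ref = x @ w_col.T @ w_row.T         # what a single GPU would compute

# gather_output=False: each rank keeps a slice of the OUTPUT features
# and never all-gathers its partial activation.
col_shards = w_col.chunk(tp, dim=0)
# input_is_parallel=True: each rank holds the matching slice of the
# INPUT features; the per-rank results are summed (the all-reduce).
row_shards = w_row.chunk(tp, dim=1)

partials = [(x @ c.T) @ r.T for c, r in zip(col_shards, row_shards)]
out = sum(partials)                 # stands in for the all-reduce

assert torch.allclose(out, ref, atol=1e-4)

The same reasoning explains why the MLP in the next hunk pairs a column-parallel `fc1` with a row-parallel `fc2`.
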
@@ -120,16 +123,18 @@ class OPTDecoderLayer(nn.Module):
         self.self_attn_layer_norm = nn.LayerNorm(
             self.embed_dim,
             elementwise_affine=config.layer_norm_elementwise_affine)
-        self.fc1 = ColumnParallelLinear(self.embed_dim,
-                                        config.ffn_dim,
-                                        bias=config.enable_bias,
-                                        gather_output=False,
-                                        perform_initialization=False)
-        self.fc2 = RowParallelLinear(config.ffn_dim,
-                                     self.embed_dim,
-                                     bias=config.enable_bias,
-                                     input_is_parallel=True,
-                                     perform_initialization=False)
+        self.fc1 = ColumnParallelLinear(
+            self.embed_dim,
+            config.ffn_dim,
+            bias=config.enable_bias,
+            gather_output=False,
+        )
+        self.fc2 = RowParallelLinear(
+            config.ffn_dim,
+            self.embed_dim,
+            bias=config.enable_bias,
+            input_is_parallel=True,
+        )
         self.final_layer_norm = nn.LayerNorm(
             self.embed_dim,
             elementwise_affine=config.layer_norm_elementwise_affine)
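
Every constructor in this diff also drops `perform_initialization=False`: after the refactor, the parallel layers simply allocate their shards and rely on the checkpoint loader to fill them, so the flag has no remaining purpose. A rough sketch of what loading one rank's column-parallel shard from a full checkpoint tensor involves (the helper name and all shapes are hypothetical, not vLLM's actual loader API):

import torch

def load_column_parallel_weight(param: torch.nn.Parameter,
                                full_weight: torch.Tensor,
                                rank: int, world_size: int) -> None:
    # Copy this rank's row slice of a full (out_features, in_features)
    # weight into a column-parallel shard. Hypothetical stand-in.
    shard_size = full_weight.shape[0] // world_size
    shard = full_weight[rank * shard_size:(rank + 1) * shard_size]
    assert param.shape == shard.shape
    param.data.copy_(shard)

# Usage: fc1 is column-parallel over config.ffn_dim, so rank r owns rows
# [r * ffn_dim // world, (r + 1) * ffn_dim // world) of the checkpoint weight.
ffn_dim, embed_dim, world = 16, 8, 2
ckpt = torch.randn(ffn_dim, embed_dim)                 # full fc1.weight
shard = torch.nn.Parameter(torch.empty(ffn_dim // world, embed_dim))
load_column_parallel_weight(shard, ckpt, rank=0, world_size=world)
assert torch.equal(shard.data, ckpt[:ffn_dim // world])
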
@@ -182,7 +187,7 @@ class OPTDecoder(nn.Module):
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
             config.word_embed_proj_dim,
-            perform_initialization=False)
+        )
         # Positional embeddings are replicated (not sharded).
         self.embed_positions = OPTLearnedPositionalEmbedding(
             config.max_position_embeddings, config.hidden_size)
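
`VocabParallelEmbedding` applies the same sharding idea to the embedding table, splitting the vocabulary rows across ranks, while the learned positional embedding below it stays replicated on every rank. A single-process sketch of the vocab-parallel lookup, with out-of-shard token ids masked to zero and the all-reduce simulated by a running sum (sizes are illustrative assumptions):

import torch

vocab, dim, tp = 12, 4, 2
torch.manual_seed(0)
full = torch.randn(vocab, dim)                  # full embedding table
token_ids = torch.tensor([0, 5, 7, 11])

per_rank = vocab // tp
out = torch.zeros(len(token_ids), dim)
for rank in range(tp):                          # simulate the TP group
    lo, hi = rank * per_rank, (rank + 1) * per_rank
    shard = full[lo:hi]                         # this rank's vocab rows
    in_range = (token_ids >= lo) & (token_ids < hi)
    local_ids = (token_ids - lo).clamp(0, per_rank - 1)
    partial = shard[local_ids] * in_range.unsqueeze(-1)  # zero others
    out += partial                              # stands in for all-reduce

assert torch.allclose(out, full[token_ids])
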