TP/quantization/weight loading refactor part 1 - Simplify parallel linear logic (#1181)

This commit is contained in:
Zhuohan Li
2023-10-02 15:36:09 -07:00
committed by GitHub
parent 84e4e37d14
commit ba0bfd40e2
42 changed files with 819 additions and 1547 deletions

View File

@@ -15,8 +15,9 @@ from vllm.model_executor.weight_utils import (convert_pyslice_to_tensor,
load_tensor_parallel_weights)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import (
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding,
ColumnParallelLinear,
RowParallelLinear)
from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs.mpt import MPTConfig
@@ -53,7 +54,6 @@ class MPTAttention(nn.Module):
3 * self.d_model,
bias=not config.no_bias,
gather_output=False,
perform_initialization=False,
)
if self.qk_ln:
self.q_ln = nn.LayerNorm(self.d_model)
@@ -63,7 +63,6 @@ class MPTAttention(nn.Module):
self.d_model,
bias=not config.no_bias,
input_is_parallel=True,
perform_initialization=False,
)
tp_world_size = get_tensor_model_parallel_world_size()
@@ -113,17 +112,19 @@ class MPTMLP(nn.Module):
hidden_size = config.d_model
expansion_ratio = config.expansion_ratio
intermediate_size = expansion_ratio * hidden_size
self.up_proj = ColumnParallelLinear(hidden_size,
intermediate_size,
bias=not config.no_bias,
gather_output=False,
perform_initialization=False)
self.up_proj = ColumnParallelLinear(
hidden_size,
intermediate_size,
bias=not config.no_bias,
gather_output=False,
)
self.act = get_act_fn("gelu")
self.down_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=not config.no_bias,
input_is_parallel=True,
perform_initialization=False)
self.down_proj = RowParallelLinear(
intermediate_size,
hidden_size,
bias=not config.no_bias,
input_is_parallel=True,
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x, _ = self.up_proj(x)
@@ -172,9 +173,10 @@ class MPTModel(nn.Module):
assert config.embedding_fraction == 1.0
assert config.norm_type == "low_precision_layernorm"
self.wte = VocabParallelEmbedding(config.vocab_size,
config.d_model,
perform_initialization=False)
self.wte = VocabParallelEmbedding(
config.vocab_size,
config.d_model,
)
self.blocks = nn.ModuleList(
[MPTBlock(config) for _ in range(config.n_layers)])
self.norm_f = nn.LayerNorm(config.d_model)