TP/quantization/weight loading refactor part 1 - Simplify parallel linear logic (#1181)
This commit is contained in:
@@ -28,7 +28,7 @@ from vllm.model_executor.parallel_utils.parallel_state import (
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
)
|
||||
from vllm.model_executor.parallel_utils.tensor_parallel import (
|
||||
from vllm.model_executor.parallel_utils.layers import (
|
||||
VocabParallelEmbedding,
|
||||
ColumnParallelLinear,
|
||||
RowParallelLinear,
|
||||
@@ -53,14 +53,12 @@ class QWenMLP(nn.Module):
|
||||
2 * intermediate_size,
|
||||
bias=False,
|
||||
gather_output=False,
|
||||
perform_initialization=False,
|
||||
)
|
||||
self.c_proj = RowParallelLinear(
|
||||
intermediate_size,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
input_is_parallel=True,
|
||||
perform_initialization=False,
|
||||
)
|
||||
if hidden_act != "silu":
|
||||
raise ValueError(f"Unsupported activation: {hidden_act}. "
|
||||
@@ -98,14 +96,12 @@ class QWenAttention(nn.Module):
|
||||
3 * hidden_size,
|
||||
bias=True,
|
||||
gather_output=False,
|
||||
perform_initialization=False,
|
||||
)
|
||||
self.c_proj = RowParallelLinear(
|
||||
self.total_num_heads * self.head_dim,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
input_is_parallel=True,
|
||||
perform_initialization=False,
|
||||
)
|
||||
self.scaling = self.head_dim**-0.5
|
||||
self.attn = PagedAttentionWithRoPE(
|
||||
@@ -190,9 +186,10 @@ class QWenModel(nn.Module):
|
||||
self.vocab_size = config.vocab_size
|
||||
|
||||
vocab_size = ((config.vocab_size + 63) // 64) * 64
|
||||
self.wte = VocabParallelEmbedding(vocab_size,
|
||||
config.hidden_size,
|
||||
perform_initialization=False)
|
||||
self.wte = VocabParallelEmbedding(
|
||||
vocab_size,
|
||||
config.hidden_size,
|
||||
)
|
||||
self.h = nn.ModuleList(
|
||||
[QWenBlock(config) for _ in range(config.num_hidden_layers)])
|
||||
self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
|
||||
@@ -235,7 +232,6 @@ class QWenLMHeadModel(nn.Module):
|
||||
vocab_size,
|
||||
bias=False,
|
||||
gather_output=False,
|
||||
perform_initialization=False,
|
||||
)
|
||||
self.sampler = Sampler(config.vocab_size)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user