[Core] Allow disabling TP sharding for parallel Linear layer (#23024)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: Isotr0py <2037008807@qq.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
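The change applies one pattern across all touched models: instead of swapping in ReplicatedLinear (or MergedReplicatedLinear) at every call site when the vision tower runs data-parallel, each call site keeps the parallel Linear class and turns sharding off through the new disable_tp argument. A minimal sketch of the before/after, modeled on the down-projection hunks below (the helper name build_down_proj is illustrative, not part of the commit):

from vllm.model_executor.layers.linear import RowParallelLinear

def build_down_proj(hidden_features, in_features, bias, quant_config,
                    prefix, use_data_parallel):
    # Before: every call site picked a different class for the DP case.
    #   cls_down = (ReplicatedLinear
    #               if use_data_parallel else RowParallelLinear)
    #   return cls_down(hidden_features, in_features, bias=bias,
    #                   quant_config=quant_config,
    #                   prefix=f"{prefix}.down_proj")
    # After: keep RowParallelLinear and let it skip TP sharding.
    return RowParallelLinear(hidden_features,
                             in_features,
                             bias=bias,
                             quant_config=quant_config,
                             prefix=f"{prefix}.down_proj",
                             disable_tp=use_data_parallel)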
@@ -43,7 +43,6 @@ from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                MergedColumnParallelLinear,
-                                               MergedReplicatedLinear,
                                                ReplicatedLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
@@ -435,12 +434,13 @@ class DeepseekV2MLAAttention(nn.Module):
         self.max_position_embeddings = max_position_embeddings
 
         if self.q_lora_rank is not None:
-            self.fused_qkv_a_proj = MergedReplicatedLinear(
+            self.fused_qkv_a_proj = MergedColumnParallelLinear(
                 self.hidden_size,
                 [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
                 bias=False,
                 quant_config=quant_config,
-                prefix=f"{prefix}.fused_qkv_a_proj")
+                prefix=f"{prefix}.fused_qkv_a_proj",
+                disable_tp=True)
         else:
             self.kv_a_proj_with_mqa = ReplicatedLinear(
                 self.hidden_size,
@@ -51,14 +51,10 @@ from vllm.distributed import utils as dist_utils
 from vllm.logger import init_logger
 from vllm.model_executor import SamplingMetadata
 from vllm.model_executor.layers.layernorm import RMSNorm
-# yapf: disable
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                MergedColumnParallelLinear,
-                                               MergedReplicatedLinear,
                                                QKVParallelLinear,
-                                               ReplicatedLinear,
                                                RowParallelLinear)
-# yapf: enable
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.module_mapping import MultiModelKeys
@@ -174,20 +170,22 @@ class Glm4vVisionMLP(nn.Module):
         use_data_parallel: bool = False,
     ):
         super().__init__()
-        cls_gate_up = (MergedReplicatedLinear
-                       if use_data_parallel else MergedColumnParallelLinear)
-        self.gate_up_proj = cls_gate_up(input_size=in_features,
-                                        output_sizes=[hidden_features] * 2,
-                                        bias=bias,
-                                        quant_config=quant_config,
-                                        prefix=f"{prefix}.gate_up_proj")
-        cls_down = (ReplicatedLinear
-                    if use_data_parallel else RowParallelLinear)
-        self.down_proj = cls_down(hidden_features,
-                                  in_features,
-                                  bias=bias,
-                                  quant_config=quant_config,
-                                  prefix=f"{prefix}.down_proj")
+        self.gate_up_proj = MergedColumnParallelLinear(
+            input_size=in_features,
+            output_sizes=[hidden_features] * 2,
+            bias=bias,
+            quant_config=quant_config,
+            prefix=f"{prefix}.gate_up_proj",
+            disable_tp=use_data_parallel,
+        )
+        self.down_proj = RowParallelLinear(
+            hidden_features,
+            in_features,
+            bias=bias,
+            quant_config=quant_config,
+            prefix=f"{prefix}.down_proj",
+            disable_tp=use_data_parallel,
+        )
         self.act_fn = SiluAndMul()
 
     def forward(self, x: torch.Tensor):
@@ -234,48 +232,32 @@ class Glm4vVisionAttention(nn.Module):
         # Per attention head and per partition values.
         self.tp_size = (1 if use_data_parallel else
                         get_tensor_model_parallel_world_size())
-        self.tp_rank = parallel_state.get_tensor_model_parallel_rank()
+        self.tp_rank = (0 if use_data_parallel else
+                        parallel_state.get_tensor_model_parallel_rank())
         self.hidden_size_per_attention_head = dist_utils.divide(
             projection_size, num_heads)
         self.num_attention_heads_per_partition = dist_utils.divide(
             num_heads, self.tp_size)
 
-        if use_data_parallel:
-            self.qkv = ReplicatedLinear(
-                input_size=embed_dim,
-                output_size=3 * projection_size,
-                bias=False,
-                quant_config=quant_config,
-                # Change qkv prefix to align with GLM-4.5V-FP8 quantization cfg
-                prefix=f"{prefix}.qkv_proj"
-                if quant_config else f"{prefix}.qkv",
-            )
-            self.proj = ReplicatedLinear(
-                input_size=projection_size,
-                output_size=embed_dim,
-                quant_config=quant_config,
-                prefix=f"{prefix}.proj",
-                bias=False,
-            )
-        else:
-            self.qkv = QKVParallelLinear(
-                hidden_size=embed_dim,
-                head_size=self.hidden_size_per_attention_head,
-                total_num_heads=num_heads,
-                total_num_kv_heads=num_heads,
-                bias=False,
-                quant_config=quant_config,
-                # Change qkv prefix to align with GLM-4.5V-FP8 quantization cfg
-                prefix=f"{prefix}.qkv_proj"
-                if quant_config else f"{prefix}.qkv",
-            )
-            self.proj = RowParallelLinear(
-                input_size=projection_size,
-                output_size=embed_dim,
-                quant_config=quant_config,
-                prefix=f"{prefix}.proj",
-                bias=False,
-            )
+        self.qkv = QKVParallelLinear(
+            hidden_size=embed_dim,
+            head_size=self.hidden_size_per_attention_head,
+            total_num_heads=num_heads,
+            total_num_kv_heads=num_heads,
+            bias=False,
+            quant_config=quant_config,
+            # Change qkv prefix to align with GLM-4.5V-FP8 quantization cfg
+            prefix=f"{prefix}.qkv_proj" if quant_config else f"{prefix}.qkv",
+            disable_tp=use_data_parallel,
+        )
+        self.proj = RowParallelLinear(
+            input_size=projection_size,
+            output_size=embed_dim,
+            quant_config=quant_config,
+            prefix=f"{prefix}.proj",
+            bias=False,
+            disable_tp=use_data_parallel,
+        )
 
         # Detect attention implementation.
         self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
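Alongside disable_tp, the data-parallel path also stops consulting the real tensor-parallel rank, so the per-partition head count falls back to the full head count. A small sketch of that bookkeeping, using only what the hunk above shows (the helper name is illustrative, not vLLM API):

def vision_attn_partitioning(num_heads, projection_size, tp_world_size,
                             tp_rank, use_data_parallel):
    # When TP sharding is disabled, the module must behave as if it were
    # the only tensor-parallel rank.
    tp_size = 1 if use_data_parallel else tp_world_size
    rank = 0 if use_data_parallel else tp_rank
    head_dim = projection_size // num_heads      # dist_utils.divide(...)
    heads_per_partition = num_heads // tp_size   # full head count under DP
    return tp_size, rank, head_dim, heads_per_partition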
@@ -494,41 +476,31 @@ class Glm4vPatchMerger(nn.Module):
     ) -> None:
         super().__init__()
         self.hidden_size = d_model
-        if use_data_parallel:
-            self.proj = ReplicatedLinear(
-                input_size=self.hidden_size,
-                output_size=self.hidden_size,
-                bias=bias,
-                quant_config=quant_config,
-                prefix=f"{prefix}.proj",
-            )
-        else:
-            self.proj = ColumnParallelLinear(
-                self.hidden_size,
-                self.hidden_size,
-                bias=bias,
-                gather_output=True,
-                quant_config=quant_config,
-                prefix=f"{prefix}.proj",
-            )
+        self.proj = ColumnParallelLinear(
+            self.hidden_size,
+            self.hidden_size,
+            bias=bias,
+            gather_output=True,
+            quant_config=quant_config,
+            prefix=f"{prefix}.proj",
+            disable_tp=use_data_parallel,
+        )
         self.post_projection_norm = nn.LayerNorm(self.hidden_size)
-        cls_gate_up = (MergedReplicatedLinear
-                       if use_data_parallel else MergedColumnParallelLinear)
-        self.gate_up_proj = cls_gate_up(
+        self.gate_up_proj = MergedColumnParallelLinear(
             input_size=self.hidden_size,
             output_sizes=[context_dim] * 2,
             bias=bias,
             quant_config=quant_config,
             prefix=f"{prefix}.gate_up_proj",
+            disable_tp=use_data_parallel,
         )
-        cls_down = (ReplicatedLinear
-                    if use_data_parallel else RowParallelLinear)
-        self.down_proj = cls_down(
+        self.down_proj = RowParallelLinear(
             context_dim,
             self.hidden_size,
             bias=bias,
             quant_config=quant_config,
             prefix=f"{prefix}.down_proj",
+            disable_tp=use_data_parallel,
         )
         self.act_fn = SiluAndMul()
         self.extra_activation_func = nn.GELU()
@@ -48,7 +48,6 @@ from vllm.model_executor.layers.layernorm import RMSNorm
 # yapf: disable
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                MergedColumnParallelLinear,
-                                               MergedReplicatedLinear,
                                                QKVParallelLinear,
                                                ReplicatedLinear,
                                                RowParallelLinear)
@@ -178,22 +177,20 @@ class Qwen2_5_VisionMLP(nn.Module):
                  prefix: str = "",
                  use_data_parallel: bool = False):
         super().__init__()
-        cls_gate_up_proj = (MergedReplicatedLinear if use_data_parallel else
-                            MergedColumnParallelLinear)
-        self.gate_up_proj = cls_gate_up_proj(
+        self.gate_up_proj = MergedColumnParallelLinear(
             input_size=in_features,
             output_sizes=[hidden_features] * 2,  # [gate_proj, up_proj]
             bias=bias,
             quant_config=quant_config,
-            prefix=f"{prefix}.gate_up_proj")
+            prefix=f"{prefix}.gate_up_proj",
+            disable_tp=use_data_parallel)
 
-        cls_down_proj = (ReplicatedLinear
-                         if use_data_parallel else RowParallelLinear)
-        self.down_proj = cls_down_proj(hidden_features,
-                                       in_features,
-                                       bias=bias,
-                                       quant_config=quant_config,
-                                       prefix=f"{prefix}.down_proj")
+        self.down_proj = RowParallelLinear(hidden_features,
+                                           in_features,
+                                           bias=bias,
+                                           quant_config=quant_config,
+                                           prefix=f"{prefix}.down_proj",
+                                           disable_tp=use_data_parallel)
         self.act_fn = act_fn
 
     def forward(self, x: torch.Tensor):
@@ -243,30 +240,21 @@ class Qwen2_5_VisionAttention(nn.Module):
         self.num_attention_heads_per_partition = dist_utils.divide(
             num_heads, self.tp_size)
 
-        if use_data_parallel:
-            self.qkv = ReplicatedLinear(embed_dim,
-                                        self.hidden_size_per_attention_head *
-                                        3 * num_heads,
-                                        bias=True,
-                                        quant_config=quant_config,
-                                        prefix=f"{prefix}.qkv")
+        self.qkv = QKVParallelLinear(
+            hidden_size=embed_dim,
+            head_size=self.hidden_size_per_attention_head,
+            total_num_heads=num_heads,
+            total_num_kv_heads=num_heads,
+            bias=True,
+            quant_config=quant_config,
+            prefix=f"{prefix}.qkv",
+            disable_tp=use_data_parallel)
 
-        else:
-            self.qkv = QKVParallelLinear(
-                hidden_size=embed_dim,
-                head_size=self.hidden_size_per_attention_head,
-                total_num_heads=num_heads,
-                total_num_kv_heads=num_heads,
-                bias=True,
-                quant_config=quant_config,
-                prefix=f"{prefix}.qkv")
-
-        cls_proj = (ReplicatedLinear
-                    if use_data_parallel else RowParallelLinear)
-        self.proj = cls_proj(input_size=projection_size,
-                             output_size=embed_dim,
-                             quant_config=quant_config,
-                             prefix=f"{prefix}.proj")
+        self.proj = RowParallelLinear(input_size=projection_size,
+                                      output_size=embed_dim,
+                                      quant_config=quant_config,
+                                      prefix=f"{prefix}.proj",
+                                      disable_tp=use_data_parallel)
 
         # Detect attention implementation.
         self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
@@ -21,7 +21,6 @@ from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                QKVParallelLinear,
-                                               ReplicatedLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
@@ -667,35 +666,21 @@ class Step3VisionAttention(nn.Module):
 
         self.q_size = self.num_heads * self.head_dim
 
-        if use_data_parallel:
-            self.qkv_proj = ReplicatedLinear(
-                self.embed_dim,
-                3 * self.q_size,
-                bias=True,
-                quant_config=quant_config,
-                prefix=prefix,
-            )
-            self.out_proj = ReplicatedLinear(
-                self.total_num_heads * self.head_dim,
-                self.embed_dim,
-                bias=True,
-                quant_config=quant_config,
-                prefix=prefix,
-            )
-        else:
-            self.qkv_proj = QKVParallelLinear(
-                self.embed_dim,
-                self.head_dim,
-                self.total_num_heads,
-                bias=True,
-                quant_config=quant_config,
-                prefix=prefix,
-            )
-            self.out_proj = RowParallelLinear(self.embed_dim,
-                                              self.embed_dim,
-                                              bias=True,
-                                              quant_config=quant_config,
-                                              prefix=prefix)
+        self.qkv_proj = QKVParallelLinear(
+            self.embed_dim,
+            self.head_dim,
+            self.total_num_heads,
+            bias=True,
+            quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
+            disable_tp=use_data_parallel,
+        )
+        self.out_proj = RowParallelLinear(self.embed_dim,
+                                          self.embed_dim,
+                                          bias=True,
+                                          quant_config=quant_config,
+                                          prefix=f"{prefix}.out_proj",
+                                          disable_tp=use_data_parallel)
 
     def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
         return tensor.view(bsz, seq_len, self.num_heads,
@@ -740,20 +725,18 @@ class Step3VisionMLP(nn.Module):
         super().__init__()
         self.config = config
         self.activation_fn = get_act_fn(config.hidden_act)
-        cls_fc1 = (ReplicatedLinear
-                   if use_data_parallel else ColumnParallelLinear)
-        self.fc1 = cls_fc1(config.hidden_size,
-                           config.intermediate_size,
-                           bias=True,
-                           quant_config=quant_config,
-                           prefix=prefix)
-        cls_fc2 = (ReplicatedLinear
-                   if use_data_parallel else RowParallelLinear)
-        self.fc2 = cls_fc2(config.intermediate_size,
-                           config.hidden_size,
-                           bias=True,
-                           quant_config=quant_config,
-                           prefix=prefix)
+        self.fc1 = ColumnParallelLinear(config.hidden_size,
+                                        config.intermediate_size,
+                                        bias=True,
+                                        quant_config=quant_config,
+                                        prefix=f"{prefix}.fc1",
+                                        disable_tp=use_data_parallel)
+        self.fc2 = RowParallelLinear(config.intermediate_size,
+                                     config.hidden_size,
+                                     bias=True,
+                                     quant_config=quant_config,
+                                     prefix=f"{prefix}.fc2",
+                                     disable_tp=use_data_parallel)
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states, _ = self.fc1(hidden_states)