[Core] Allow disabling TP sharding for parallel Linear layer (#23024)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: Isotr0py <2037008807@qq.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
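The flag introduced here lets model code ask a tensor-parallel linear layer to skip sharding, instead of swapping in a separate replicated class per layer. Below is a minimal sketch of that idea using a hypothetical ToyRowParallelLinear; it is not vLLM's real RowParallelLinear, which additionally handles quantization, weight loading, and the tensor-parallel all-reduce.

import torch
import torch.nn as nn


class ToyRowParallelLinear(nn.Module):
    """Hypothetical stand-in: shards the input dimension across tp_size
    ranks unless disable_tp is set, in which case the full weight is kept
    and the layer behaves like a plain replicated nn.Linear."""

    def __init__(self, input_size: int, output_size: int, *,
                 tp_size: int = 1, disable_tp: bool = False):
        super().__init__()
        shard = input_size if disable_tp else input_size // tp_size
        self.linear = nn.Linear(shard, output_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # With TP enabled each rank would hold a partial sum and all-reduce
        # it here; with disable_tp there is nothing to reduce.
        return self.linear(x)


# Callers keep one class and forward the flag, as the diff below does,
# instead of `ReplicatedLinear if use_data_parallel else RowParallelLinear`.
use_data_parallel = True
proj = ToyRowParallelLinear(4096, 1024, tp_size=4,
                            disable_tp=use_data_parallel)
print(proj(torch.randn(2, 4096)).shape)  # torch.Size([2, 1024])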
@@ -48,7 +48,6 @@ from vllm.model_executor.layers.layernorm import RMSNorm
 # yapf: disable
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                MergedColumnParallelLinear,
-                                               MergedReplicatedLinear,
                                                QKVParallelLinear,
                                                ReplicatedLinear,
                                                RowParallelLinear)
@@ -178,22 +177,20 @@ class Qwen2_5_VisionMLP(nn.Module):
                  prefix: str = "",
                  use_data_parallel: bool = False):
         super().__init__()
-        cls_gate_up_proj = (MergedReplicatedLinear if use_data_parallel else
-                            MergedColumnParallelLinear)
-        self.gate_up_proj = cls_gate_up_proj(
+        self.gate_up_proj = MergedColumnParallelLinear(
             input_size=in_features,
             output_sizes=[hidden_features] * 2,  # [gate_proj, up_proj]
             bias=bias,
             quant_config=quant_config,
-            prefix=f"{prefix}.gate_up_proj")
+            prefix=f"{prefix}.gate_up_proj",
+            disable_tp=use_data_parallel)

-        cls_down_proj = (ReplicatedLinear
-                         if use_data_parallel else RowParallelLinear)
-        self.down_proj = cls_down_proj(hidden_features,
-                                       in_features,
-                                       bias=bias,
-                                       quant_config=quant_config,
-                                       prefix=f"{prefix}.down_proj")
+        self.down_proj = RowParallelLinear(hidden_features,
+                                           in_features,
+                                           bias=bias,
+                                           quant_config=quant_config,
+                                           prefix=f"{prefix}.down_proj",
+                                           disable_tp=use_data_parallel)
         self.act_fn = act_fn

     def forward(self, x: torch.Tensor):
@@ -243,30 +240,21 @@ class Qwen2_5_VisionAttention(nn.Module):
         self.num_attention_heads_per_partition = dist_utils.divide(
             num_heads, self.tp_size)

-        if use_data_parallel:
-            self.qkv = ReplicatedLinear(embed_dim,
-                                        self.hidden_size_per_attention_head *
-                                        3 * num_heads,
-                                        bias=True,
-                                        quant_config=quant_config,
-                                        prefix=f"{prefix}.qkv")
-        else:
-            self.qkv = QKVParallelLinear(
-                hidden_size=embed_dim,
-                head_size=self.hidden_size_per_attention_head,
-                total_num_heads=num_heads,
-                total_num_kv_heads=num_heads,
-                bias=True,
-                quant_config=quant_config,
-                prefix=f"{prefix}.qkv")
+        self.qkv = QKVParallelLinear(
+            hidden_size=embed_dim,
+            head_size=self.hidden_size_per_attention_head,
+            total_num_heads=num_heads,
+            total_num_kv_heads=num_heads,
+            bias=True,
+            quant_config=quant_config,
+            prefix=f"{prefix}.qkv",
+            disable_tp=use_data_parallel)

-        cls_proj = (ReplicatedLinear
-                    if use_data_parallel else RowParallelLinear)
-        self.proj = cls_proj(input_size=projection_size,
-                             output_size=embed_dim,
-                             quant_config=quant_config,
-                             prefix=f"{prefix}.proj")
+        self.proj = RowParallelLinear(input_size=projection_size,
+                                      output_size=embed_dim,
+                                      quant_config=quant_config,
+                                      prefix=f"{prefix}.proj",
+                                      disable_tp=use_data_parallel)

         # Detect attention implementation.
         self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)