[Core] Allow disabling TP sharding for parallel Linear layer (#23024)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: Isotr0py <2037008807@qq.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Isotr0py
2025-09-06 13:53:58 +08:00
committed by GitHub
parent 6432739ef1
commit 53b19ccdd5
7 changed files with 203 additions and 280 deletions
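
The core change: instead of each call site switching between a replicated and a tensor-parallel Linear class when `use_data_parallel` is set, the parallel classes now accept a `disable_tp` flag directly. A minimal sketch of the before/after pattern, reusing names from the diff below (sizes are illustrative, and actually constructing these layers requires vLLM's distributed environment to be initialized):

    from vllm.model_executor.layers.linear import (ReplicatedLinear,
                                                   RowParallelLinear)

    use_data_parallel = True
    hidden_features, in_features = 3420, 1280  # illustrative sizes

    # Before: pick the class by hand at every call site.
    cls_down_proj = (ReplicatedLinear
                     if use_data_parallel else RowParallelLinear)
    down_proj = cls_down_proj(hidden_features, in_features, bias=False)

    # After: always build the parallel class; disable_tp=True makes it
    # keep the full, unsharded weight, like ReplicatedLinear would.
    down_proj = RowParallelLinear(hidden_features,
                                  in_features,
                                  bias=False,
                                  disable_tp=use_data_parallel)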

vllm/model_executor/models/qwen2_5_vl.py

@@ -48,7 +48,6 @@ from vllm.model_executor.layers.layernorm import RMSNorm
 # yapf: disable
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                MergedColumnParallelLinear,
-                                               MergedReplicatedLinear,
                                                QKVParallelLinear,
                                                ReplicatedLinear,
                                                RowParallelLinear)
@@ -178,22 +177,20 @@ class Qwen2_5_VisionMLP(nn.Module):
                  prefix: str = "",
                  use_data_parallel: bool = False):
         super().__init__()
-        cls_gate_up_proj = (MergedReplicatedLinear if use_data_parallel else
-                            MergedColumnParallelLinear)
-        self.gate_up_proj = cls_gate_up_proj(
+        self.gate_up_proj = MergedColumnParallelLinear(
             input_size=in_features,
             output_sizes=[hidden_features] * 2,  # [gate_proj, up_proj]
             bias=bias,
             quant_config=quant_config,
-            prefix=f"{prefix}.gate_up_proj")
+            prefix=f"{prefix}.gate_up_proj",
+            disable_tp=use_data_parallel)
-        cls_down_proj = (ReplicatedLinear
-                         if use_data_parallel else RowParallelLinear)
-        self.down_proj = cls_down_proj(hidden_features,
-                                       in_features,
-                                       bias=bias,
-                                       quant_config=quant_config,
-                                       prefix=f"{prefix}.down_proj")
+        self.down_proj = RowParallelLinear(hidden_features,
+                                           in_features,
+                                           bias=bias,
+                                           quant_config=quant_config,
+                                           prefix=f"{prefix}.down_proj",
+                                           disable_tp=use_data_parallel)
         self.act_fn = act_fn
 
     def forward(self, x: torch.Tensor):
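
For intuition about what `disable_tp` does in the row-parallel case, here is a toy, self-contained model of the shape and collective logic; this is not vLLM's implementation, only an illustration of the semantics:

    import torch
    import torch.nn as nn

    class ToyRowParallelLinear(nn.Module):
        """Toy model of a row-parallel linear layer (not vLLM's code).

        Row parallelism splits the input dimension across tp_size ranks,
        so each rank produces a partial product that must be all-reduced.
        With disable_tp=True the layer keeps the full weight on every
        rank and skips the collective, i.e. ReplicatedLinear semantics.
        """

        def __init__(self, input_size: int, output_size: int,
                     tp_size: int = 1, disable_tp: bool = False):
            super().__init__()
            sharded = not disable_tp and tp_size > 1
            in_per_rank = input_size // tp_size if sharded else input_size
            self.weight = nn.Parameter(torch.zeros(output_size, in_per_rank))
            self.needs_all_reduce = sharded

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            out = x @ self.weight.t()  # partial result when sharded
            # A real TP layer would all-reduce `out` across ranks here
            # when self.needs_all_reduce is True.
            return out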
@@ -243,30 +240,21 @@ class Qwen2_5_VisionAttention(nn.Module):
         self.num_attention_heads_per_partition = dist_utils.divide(
             num_heads, self.tp_size)
 
-        if use_data_parallel:
-            self.qkv = ReplicatedLinear(embed_dim,
-                                        self.hidden_size_per_attention_head *
-                                        3 * num_heads,
-                                        bias=True,
-                                        quant_config=quant_config,
-                                        prefix=f"{prefix}.qkv")
-        else:
-            self.qkv = QKVParallelLinear(
-                hidden_size=embed_dim,
-                head_size=self.hidden_size_per_attention_head,
-                total_num_heads=num_heads,
-                total_num_kv_heads=num_heads,
-                bias=True,
-                quant_config=quant_config,
-                prefix=f"{prefix}.qkv")
+        self.qkv = QKVParallelLinear(
+            hidden_size=embed_dim,
+            head_size=self.hidden_size_per_attention_head,
+            total_num_heads=num_heads,
+            total_num_kv_heads=num_heads,
+            bias=True,
+            quant_config=quant_config,
+            prefix=f"{prefix}.qkv",
+            disable_tp=use_data_parallel)
-        cls_proj = (ReplicatedLinear
-                    if use_data_parallel else RowParallelLinear)
-        self.proj = cls_proj(input_size=projection_size,
-                             output_size=embed_dim,
-                             quant_config=quant_config,
-                             prefix=f"{prefix}.proj")
+        self.proj = RowParallelLinear(input_size=projection_size,
+                                      output_size=embed_dim,
+                                      quant_config=quant_config,
+                                      prefix=f"{prefix}.proj",
+                                      disable_tp=use_data_parallel)
 
         # Detect attention implementation.
         self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
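
Note that the data-parallel path previously fell back to a plain ReplicatedLinear for the fused QKV projection (output size head_dim * 3 * num_heads); keeping QKVParallelLinear and passing disable_tp preserves a single code path, including its fused q/k/v weight handling, while skipping the head sharding. A construction sketch (the sizes and prefix are illustrative, and running it requires vLLM's distributed state to be initialized):

    from vllm.model_executor.layers.linear import QKVParallelLinear

    embed_dim, num_heads = 1280, 16       # illustrative ViT sizes
    head_dim = embed_dim // num_heads

    qkv = QKVParallelLinear(
        hidden_size=embed_dim,
        head_size=head_dim,
        total_num_heads=num_heads,
        total_num_kv_heads=num_heads,     # ViT attention: kv heads == q heads
        bias=True,
        prefix="visual.blocks.0.attn.qkv",  # hypothetical prefix
        disable_tp=True,                  # replicate instead of TP-sharding
    )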