[Core] Allow disabling TP sharding for parallel Linear layer (#23024)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: Isotr0py <2037008807@qq.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Author: Isotr0py
Date: 2025-09-06 13:53:58 +08:00
Committed by: GitHub
Parent: 6432739ef1
Commit: 53b19ccdd5

7 changed files with 203 additions and 280 deletions
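
The change replaces per-mode class selection (ReplicatedLinear / MergedReplicatedLinear
under data parallelism vs. the *ParallelLinear classes under tensor parallelism) with a
single disable_tp flag on the parallel Linear layers. The standalone sketch below is
illustrative only (ToyColumnParallelLinear is not a vLLM class); it shows the core idea:
with disable_tp=True the layer allocates the full weight instead of a 1/tp_size shard.

import torch
import torch.nn as nn

class ToyColumnParallelLinear(nn.Module):
    """Illustrative stand-in for a TP-sharded linear layer."""

    def __init__(self, input_size: int, output_size: int,
                 tp_size: int = 1, disable_tp: bool = False):
        super().__init__()
        # disable_tp=True makes the layer behave like a plain replicated
        # nn.Linear, regardless of the tensor-parallel world size.
        effective_tp = 1 if disable_tp else tp_size
        self.weight = nn.Parameter(
            torch.empty(output_size // effective_tp, input_size))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Each rank computes its slice of the output columns (or the full
        # output when TP is disabled for this layer).
        return x @ self.weight.t()

replicated = ToyColumnParallelLinear(64, 128, tp_size=4, disable_tp=True)
sharded = ToyColumnParallelLinear(64, 128, tp_size=4)
assert replicated.weight.shape == (128, 64)  # full output dim on every rank
assert sharded.weight.shape == (32, 64)      # output dim / tp_size per rank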

vllm/model_executor/models/deepseek_v2.py

@@ -43,7 +43,6 @@ from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                MergedColumnParallelLinear,
-                                               MergedReplicatedLinear,
                                                ReplicatedLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
@@ -435,12 +434,13 @@ class DeepseekV2MLAAttention(nn.Module):
         self.max_position_embeddings = max_position_embeddings
 
         if self.q_lora_rank is not None:
-            self.fused_qkv_a_proj = MergedReplicatedLinear(
+            self.fused_qkv_a_proj = MergedColumnParallelLinear(
                 self.hidden_size,
                 [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
                 bias=False,
                 quant_config=quant_config,
-                prefix=f"{prefix}.fused_qkv_a_proj")
+                prefix=f"{prefix}.fused_qkv_a_proj",
+                disable_tp=True)
         else:
             self.kv_a_proj_with_mqa = ReplicatedLinear(
                 self.hidden_size,
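
With disable_tp=True the fused q_a/kv_a projection keeps the semantics of the removed
MergedReplicatedLinear: every rank holds both output partitions in full. A quick shape
check, using typical DeepSeek-V2 dimensions (the concrete numbers are assumptions for
illustration, not read from this diff):

# Typical DeepSeek-V2 MLA dimensions (assumed for illustration).
q_lora_rank, kv_lora_rank, qk_rope_head_dim = 1536, 512, 64
output_sizes = [q_lora_rank, kv_lora_rank + qk_rope_head_dim]

# disable_tp=True: each rank keeps the full fused weight.
assert sum(output_sizes) == 2112

# Without disable_tp, each merged partition would be sharded separately,
# e.g. with tp_size=4 a rank would hold only:
tp_size = 4
assert sum(size // tp_size for size in output_sizes) == 528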

vllm/model_executor/models/glm4_1v.py

@@ -51,14 +51,10 @@ from vllm.distributed import utils as dist_utils
 from vllm.logger import init_logger
 from vllm.model_executor import SamplingMetadata
 from vllm.model_executor.layers.layernorm import RMSNorm
-# yapf: disable
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                MergedColumnParallelLinear,
-                                               MergedReplicatedLinear,
                                                QKVParallelLinear,
-                                               ReplicatedLinear,
                                                RowParallelLinear)
-# yapf: enable
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.module_mapping import MultiModelKeys
@@ -174,20 +170,22 @@ class Glm4vVisionMLP(nn.Module):
         use_data_parallel: bool = False,
     ):
         super().__init__()
-        cls_gate_up = (MergedReplicatedLinear
-                       if use_data_parallel else MergedColumnParallelLinear)
-        self.gate_up_proj = cls_gate_up(input_size=in_features,
-                                        output_sizes=[hidden_features] * 2,
-                                        bias=bias,
-                                        quant_config=quant_config,
-                                        prefix=f"{prefix}.gate_up_proj")
-        cls_down = (ReplicatedLinear
-                    if use_data_parallel else RowParallelLinear)
-        self.down_proj = cls_down(hidden_features,
-                                  in_features,
-                                  bias=bias,
-                                  quant_config=quant_config,
-                                  prefix=f"{prefix}.down_proj")
+        self.gate_up_proj = MergedColumnParallelLinear(
+            input_size=in_features,
+            output_sizes=[hidden_features] * 2,
+            bias=bias,
+            quant_config=quant_config,
+            prefix=f"{prefix}.gate_up_proj",
+            disable_tp=use_data_parallel,
+        )
+        self.down_proj = RowParallelLinear(
+            hidden_features,
+            in_features,
+            bias=bias,
+            quant_config=quant_config,
+            prefix=f"{prefix}.down_proj",
+            disable_tp=use_data_parallel,
+        )
         self.act_fn = SiluAndMul()
 
     def forward(self, x: torch.Tensor):
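
Note that the layer stays a MergedColumnParallelLinear in both modes. The merged class
matters because each fused sub-projection (gate_proj, up_proj) is sharded separately
rather than splitting the concatenated weight as one block; disable_tp only turns that
sharding off. A small index sketch (dimensions are illustrative):

hidden_features, tp_size, rank = 8, 2, 0
output_sizes = [hidden_features] * 2      # [gate_proj, up_proj]
shard = hidden_features // tp_size        # 4 rows per partition per rank

# Rows of the concatenated [gate; up] weight owned by rank 0 under TP:
rows = list(range(rank * shard, (rank + 1) * shard))        # gate slice
rows += list(range(hidden_features + rank * shard,
                   hidden_features + (rank + 1) * shard))   # up slice
assert rows == [0, 1, 2, 3, 8, 9, 10, 11]

# With disable_tp=use_data_parallel and DP enabled, shard == hidden_features,
# so every rank keeps all 16 rows.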
@@ -234,48 +232,32 @@ class Glm4vVisionAttention(nn.Module):
         # Per attention head and per partition values.
         self.tp_size = (1 if use_data_parallel else
                         get_tensor_model_parallel_world_size())
-        self.tp_rank = parallel_state.get_tensor_model_parallel_rank()
+        self.tp_rank = (0 if use_data_parallel else
+                        parallel_state.get_tensor_model_parallel_rank())
         self.hidden_size_per_attention_head = dist_utils.divide(
             projection_size, num_heads)
         self.num_attention_heads_per_partition = dist_utils.divide(
             num_heads, self.tp_size)
 
-        if use_data_parallel:
-            self.qkv = ReplicatedLinear(
-                input_size=embed_dim,
-                output_size=3 * projection_size,
-                bias=False,
-                quant_config=quant_config,
-                # Change qkv prefix to align with GLM-4.5V-FP8 quantization cfg
-                prefix=f"{prefix}.qkv_proj"
-                if quant_config else f"{prefix}.qkv",
-            )
-            self.proj = ReplicatedLinear(
-                input_size=projection_size,
-                output_size=embed_dim,
-                quant_config=quant_config,
-                prefix=f"{prefix}.proj",
-                bias=False,
-            )
-        else:
-            self.qkv = QKVParallelLinear(
-                hidden_size=embed_dim,
-                head_size=self.hidden_size_per_attention_head,
-                total_num_heads=num_heads,
-                total_num_kv_heads=num_heads,
-                bias=False,
-                quant_config=quant_config,
-                # Change qkv prefix to align with GLM-4.5V-FP8 quantization cfg
-                prefix=f"{prefix}.qkv_proj"
-                if quant_config else f"{prefix}.qkv",
-            )
-            self.proj = RowParallelLinear(
-                input_size=projection_size,
-                output_size=embed_dim,
-                quant_config=quant_config,
-                prefix=f"{prefix}.proj",
-                bias=False,
-            )
+        self.qkv = QKVParallelLinear(
+            hidden_size=embed_dim,
+            head_size=self.hidden_size_per_attention_head,
+            total_num_heads=num_heads,
+            total_num_kv_heads=num_heads,
+            bias=False,
+            quant_config=quant_config,
+            # Change qkv prefix to align with GLM-4.5V-FP8 quantization cfg
+            prefix=f"{prefix}.qkv_proj" if quant_config else f"{prefix}.qkv",
+            disable_tp=use_data_parallel,
+        )
+        self.proj = RowParallelLinear(
+            input_size=projection_size,
+            output_size=embed_dim,
+            quant_config=quant_config,
+            prefix=f"{prefix}.proj",
+            bias=False,
+            disable_tp=use_data_parallel,
+        )
 
         # Detect attention implementation.
         self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
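
Under data parallelism this hunk pins tp_size to 1 and tp_rank to 0, so the
per-partition head count equals the full head count and the qkv/proj layers (with
disable_tp=True) stay unsharded; weight loading and the attention math then agree on
shapes. A sanity check of that arithmetic (head counts here are illustrative):

def divide(numerator: int, denominator: int) -> int:
    # Mirrors the behavior of dist_utils.divide: exact integer division.
    assert numerator % denominator == 0
    return numerator // denominator

num_heads, projection_size = 16, 2048
head_dim = divide(projection_size, num_heads)   # 128, identical in both modes

assert divide(num_heads, 1) == 16   # use_data_parallel: every rank sees all heads
assert divide(num_heads, 4) == 4    # 4-way TP: heads split across ranks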
@@ -494,41 +476,31 @@ class Glm4vPatchMerger(nn.Module):
     ) -> None:
         super().__init__()
         self.hidden_size = d_model
-        if use_data_parallel:
-            self.proj = ReplicatedLinear(
-                input_size=self.hidden_size,
-                output_size=self.hidden_size,
-                bias=bias,
-                quant_config=quant_config,
-                prefix=f"{prefix}.proj",
-            )
-        else:
-            self.proj = ColumnParallelLinear(
-                self.hidden_size,
-                self.hidden_size,
-                bias=bias,
-                gather_output=True,
-                quant_config=quant_config,
-                prefix=f"{prefix}.proj",
-            )
+        self.proj = ColumnParallelLinear(
+            self.hidden_size,
+            self.hidden_size,
+            bias=bias,
+            gather_output=True,
+            quant_config=quant_config,
+            prefix=f"{prefix}.proj",
+            disable_tp=use_data_parallel,
+        )
         self.post_projection_norm = nn.LayerNorm(self.hidden_size)
-        cls_gate_up = (MergedReplicatedLinear
-                       if use_data_parallel else MergedColumnParallelLinear)
-        self.gate_up_proj = cls_gate_up(
+        self.gate_up_proj = MergedColumnParallelLinear(
             input_size=self.hidden_size,
             output_sizes=[context_dim] * 2,
             bias=bias,
             quant_config=quant_config,
             prefix=f"{prefix}.gate_up_proj",
+            disable_tp=use_data_parallel,
         )
-        cls_down = (ReplicatedLinear
-                    if use_data_parallel else RowParallelLinear)
-        self.down_proj = cls_down(
+        self.down_proj = RowParallelLinear(
             context_dim,
             self.hidden_size,
             bias=bias,
             quant_config=quant_config,
             prefix=f"{prefix}.down_proj",
+            disable_tp=use_data_parallel,
         )
         self.act_fn = SiluAndMul()
         self.extra_activation_func = nn.GELU()
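
The merger's proj keeps gather_output=True: a column-parallel linear normally returns
only its local shard of the output, and gathering concatenates the shards back to the
full hidden size on every rank (a no-op once disable_tp makes the layer effectively
replicated). A runnable single-process sketch of that gather:

import torch

def column_parallel_forward(x, weight_shards, gather_output=True):
    # Each shard computes its slice of the output columns.
    partials = [x @ w.t() for w in weight_shards]
    # gather_output=True: concatenate the slices into the full output.
    return torch.cat(partials, dim=-1) if gather_output else partials[0]

hidden = 8
x = torch.randn(2, hidden)
full_weight = torch.randn(hidden, hidden)
shards = list(torch.chunk(full_weight, 2, dim=0))   # split output rows 2-way
y = column_parallel_forward(x, shards)
assert torch.allclose(y, x @ full_weight.t(), atol=1e-6)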

vllm/model_executor/models/qwen2_5_vl.py

@@ -48,7 +48,6 @@ from vllm.model_executor.layers.layernorm import RMSNorm
 # yapf: disable
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                MergedColumnParallelLinear,
-                                               MergedReplicatedLinear,
                                                QKVParallelLinear,
                                                ReplicatedLinear,
                                                RowParallelLinear)
@@ -178,22 +177,20 @@ class Qwen2_5_VisionMLP(nn.Module):
                  prefix: str = "",
                  use_data_parallel: bool = False):
         super().__init__()
-        cls_gate_up_proj = (MergedReplicatedLinear if use_data_parallel else
-                            MergedColumnParallelLinear)
-        self.gate_up_proj = cls_gate_up_proj(
+        self.gate_up_proj = MergedColumnParallelLinear(
             input_size=in_features,
             output_sizes=[hidden_features] * 2,  # [gate_proj, up_proj]
             bias=bias,
             quant_config=quant_config,
-            prefix=f"{prefix}.gate_up_proj")
-        cls_down_proj = (ReplicatedLinear
-                         if use_data_parallel else RowParallelLinear)
-        self.down_proj = cls_down_proj(hidden_features,
-                                       in_features,
-                                       bias=bias,
-                                       quant_config=quant_config,
-                                       prefix=f"{prefix}.down_proj")
+            prefix=f"{prefix}.gate_up_proj",
+            disable_tp=use_data_parallel)
+        self.down_proj = RowParallelLinear(hidden_features,
+                                           in_features,
+                                           bias=bias,
+                                           quant_config=quant_config,
+                                           prefix=f"{prefix}.down_proj",
+                                           disable_tp=use_data_parallel)
         self.act_fn = act_fn
 
     def forward(self, x: torch.Tensor):
@@ -243,30 +240,21 @@ class Qwen2_5_VisionAttention(nn.Module):
         self.num_attention_heads_per_partition = dist_utils.divide(
             num_heads, self.tp_size)
 
-        if use_data_parallel:
-            self.qkv = ReplicatedLinear(embed_dim,
-                                        self.hidden_size_per_attention_head *
-                                        3 * num_heads,
-                                        bias=True,
-                                        quant_config=quant_config,
-                                        prefix=f"{prefix}.qkv")
-        else:
-            self.qkv = QKVParallelLinear(
-                hidden_size=embed_dim,
-                head_size=self.hidden_size_per_attention_head,
-                total_num_heads=num_heads,
-                total_num_kv_heads=num_heads,
-                bias=True,
-                quant_config=quant_config,
-                prefix=f"{prefix}.qkv")
-        cls_proj = (ReplicatedLinear
-                    if use_data_parallel else RowParallelLinear)
-        self.proj = cls_proj(input_size=projection_size,
-                             output_size=embed_dim,
-                             quant_config=quant_config,
-                             prefix=f"{prefix}.proj")
+        self.qkv = QKVParallelLinear(
+            hidden_size=embed_dim,
+            head_size=self.hidden_size_per_attention_head,
+            total_num_heads=num_heads,
+            total_num_kv_heads=num_heads,
+            bias=True,
+            quant_config=quant_config,
+            prefix=f"{prefix}.qkv",
+            disable_tp=use_data_parallel)
+        self.proj = RowParallelLinear(input_size=projection_size,
+                                      output_size=embed_dim,
+                                      quant_config=quant_config,
+                                      prefix=f"{prefix}.proj",
+                                      disable_tp=use_data_parallel)
 
         # Detect attention implementation.
         self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
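
The removed data-parallel branch sized qkv explicitly as head_dim * 3 * num_heads via
ReplicatedLinear; QKVParallelLinear with disable_tp=use_data_parallel yields the same
full-size fused weight while reusing the q/k/v-aware weight loader. Shape check with
illustrative dimensions:

embed_dim, num_heads = 1280, 16
head_dim = embed_dim // num_heads            # 80

# Old DP branch: ReplicatedLinear(embed_dim, head_dim * 3 * num_heads)
old_dp_output = head_dim * 3 * num_heads

# New path: QKVParallelLinear with equal q/k/v head counts and an
# effective tp_size of 1 when disable_tp=True.
q_size = num_heads * head_dim
kv_size = num_heads * head_dim
assert old_dp_output == q_size + 2 * kv_size == 3840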

vllm/model_executor/models/step3_vl.py

@@ -21,7 +21,6 @@ from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                QKVParallelLinear,
-                                               ReplicatedLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
@@ -667,35 +666,21 @@ class Step3VisionAttention(nn.Module):
         self.q_size = self.num_heads * self.head_dim
 
-        if use_data_parallel:
-            self.qkv_proj = ReplicatedLinear(
-                self.embed_dim,
-                3 * self.q_size,
-                bias=True,
-                quant_config=quant_config,
-                prefix=prefix,
-            )
-            self.out_proj = ReplicatedLinear(
-                self.total_num_heads * self.head_dim,
-                self.embed_dim,
-                bias=True,
-                quant_config=quant_config,
-                prefix=prefix,
-            )
-        else:
-            self.qkv_proj = QKVParallelLinear(
-                self.embed_dim,
-                self.head_dim,
-                self.total_num_heads,
-                bias=True,
-                quant_config=quant_config,
-                prefix=prefix,
-            )
-            self.out_proj = RowParallelLinear(self.embed_dim,
-                                              self.embed_dim,
-                                              bias=True,
-                                              quant_config=quant_config,
-                                              prefix=prefix)
+        self.qkv_proj = QKVParallelLinear(
+            self.embed_dim,
+            self.head_dim,
+            self.total_num_heads,
+            bias=True,
+            quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
+            disable_tp=use_data_parallel,
+        )
+        self.out_proj = RowParallelLinear(self.embed_dim,
+                                          self.embed_dim,
+                                          bias=True,
+                                          quant_config=quant_config,
+                                          prefix=f"{prefix}.out_proj",
+                                          disable_tp=use_data_parallel)
 
     def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
         return tensor.view(bsz, seq_len, self.num_heads,
@@ -740,20 +725,18 @@ class Step3VisionMLP(nn.Module):
         super().__init__()
         self.config = config
         self.activation_fn = get_act_fn(config.hidden_act)
-        cls_fc1 = (ReplicatedLinear
-                   if use_data_parallel else ColumnParallelLinear)
-        self.fc1 = cls_fc1(config.hidden_size,
-                           config.intermediate_size,
-                           bias=True,
-                           quant_config=quant_config,
-                           prefix=prefix)
-        cls_fc2 = (ReplicatedLinear
-                   if use_data_parallel else RowParallelLinear)
-        self.fc2 = cls_fc2(config.intermediate_size,
-                           config.hidden_size,
-                           bias=True,
-                           quant_config=quant_config,
-                           prefix=prefix)
+        self.fc1 = ColumnParallelLinear(config.hidden_size,
+                                        config.intermediate_size,
+                                        bias=True,
+                                        quant_config=quant_config,
+                                        prefix=f"{prefix}.fc1",
+                                        disable_tp=use_data_parallel)
+        self.fc2 = RowParallelLinear(config.intermediate_size,
+                                     config.hidden_size,
+                                     bias=True,
+                                     quant_config=quant_config,
+                                     prefix=f"{prefix}.fc2",
+                                     disable_tp=use_data_parallel)
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states, _ = self.fc1(hidden_states)
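
Besides folding the two branches into one, the Step3 hunks also give each sub-layer a
distinct prefix (.qkv_proj, .out_proj, .fc1, .fc2) instead of reusing the parent's.
That plausibly matters because quantization configs match per-module rules by fully
qualified name; the sketch below (with an invented config shape and helper) shows why
shared prefixes are ambiguous:

# Quantization overrides are typically keyed by module path, so every
# sub-layer needs its own fully qualified name. (Illustrative data.)
ignored_layers = {"vision_model.transformer.layers.0.mlp.fc2"}

def is_layer_skipped(prefix: str) -> bool:
    return prefix in ignored_layers

parent = "vision_model.transformer.layers.0.mlp"
# Old: fc1 and fc2 both registered under the parent prefix, so a rule
# aimed at fc2 could never match exactly one of them.
assert not is_layer_skipped(parent)
# New: f"{prefix}.fc1" / f"{prefix}.fc2" give each layer a unique key.
assert not is_layer_skipped(f"{parent}.fc1")
assert is_layer_skipped(f"{parent}.fc2")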