[FEAT] [Performance] Enable DP for ViT in Qwen2.5VL (#22742)

Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
TJian
2025-08-19 08:25:57 -07:00
committed by GitHub
parent 4d9c61993a
commit 1298c67795
5 changed files with 633 additions and 48 deletions
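At a glance, the change threads a `use_data_parallel` flag through the whole vision tower and, when it is set, swaps each tensor-parallel linear layer for its replicated counterpart, so every DP rank holds the full ViT weights and runs it on its own shard of images without collectives inside the tower. A minimal sketch of the selection pattern repeated throughout the diff (the class names are the real vLLM layers; the helper function itself is only illustrative):

```python
# Sketch of the layer-selection pattern this commit applies to the ViT.
# With use_data_parallel=True the ViT is replicated on every rank, so the
# TP-sharded layers (ColumnParallelLinear / RowParallelLinear / QKVParallelLinear)
# are replaced by replicated variants and no TP collectives are needed.
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               ReplicatedLinear,
                                               RowParallelLinear)


def pick_linear_classes(use_data_parallel: bool):
    """Mirror the cls_fc1/cls_fc2 (and cls_proj/cls_down_proj) choice below."""
    if use_data_parallel:
        return ReplicatedLinear, ReplicatedLinear
    return ColumnParallelLinear, RowParallelLinear
```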

vllm/model_executor/models/qwen2_5_vl.py

@@ -45,10 +45,14 @@ from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.activation import get_act_and_mul_fn
from vllm.model_executor.layers.layernorm import RMSNorm
# yapf: disable
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
MergedReplicatedLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear)
# yapf: enable
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.gptq import GPTQConfig
from vllm.model_executor.layers.quantization.gptq_marlin import (
@@ -57,6 +61,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalFieldConfig
from vllm.multimodal.utils import run_dp_sharded_mrope_vision_model
from vllm.platforms import _Backend
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.config import uses_mrope
@@ -170,19 +175,25 @@ class Qwen2_5_VisionMLP(nn.Module):
bias: bool = False,
act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
prefix: str = "",
use_data_parallel: bool = False):
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
cls_gate_up_proj = (MergedReplicatedLinear if use_data_parallel else
MergedColumnParallelLinear)
self.gate_up_proj = cls_gate_up_proj(
input_size=in_features,
output_sizes=[hidden_features] * 2, # [gate_proj, up_proj]
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj")
self.down_proj = RowParallelLinear(hidden_features,
in_features,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.down_proj")
cls_down_proj = (ReplicatedLinear
if use_data_parallel else RowParallelLinear)
self.down_proj = cls_down_proj(hidden_features,
in_features,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.down_proj")
self.act_fn = act_fn
def forward(self, x: torch.Tensor):
@@ -220,28 +231,42 @@ class Qwen2_5_VisionAttention(nn.Module):
projection_size: int,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
use_data_parallel: bool = False,
) -> None:
super().__init__()
# Per attention head and per partition values.
self.tp_size = parallel_state.get_tensor_model_parallel_world_size()
self.tp_size = (1 if use_data_parallel else
parallel_state.get_tensor_model_parallel_world_size())
self.tp_rank = parallel_state.get_tensor_model_parallel_rank()
self.hidden_size_per_attention_head = dist_utils.divide(
projection_size, num_heads)
self.num_attention_heads_per_partition = dist_utils.divide(
num_heads, self.tp_size)
self.qkv = QKVParallelLinear(
hidden_size=embed_dim,
head_size=self.hidden_size_per_attention_head,
total_num_heads=num_heads,
total_num_kv_heads=num_heads,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.qkv")
self.proj = RowParallelLinear(input_size=projection_size,
output_size=embed_dim,
quant_config=quant_config,
prefix=f"{prefix}.proj")
if use_data_parallel:
self.qkv = ReplicatedLinear(embed_dim,
self.hidden_size_per_attention_head *
3 * num_heads,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.qkv")
else:
self.qkv = QKVParallelLinear(
hidden_size=embed_dim,
head_size=self.hidden_size_per_attention_head,
total_num_heads=num_heads,
total_num_kv_heads=num_heads,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.qkv")
cls_proj = (ReplicatedLinear
if use_data_parallel else RowParallelLinear)
self.proj = cls_proj(input_size=projection_size,
output_size=embed_dim,
quant_config=quant_config,
prefix=f"{prefix}.proj")
# Detect attention implementation.
self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
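Forcing `tp_size = 1` in the DP path means the head partitioning above degenerates to "every rank keeps all heads". A small worked example; the `divide` stand-in below mimics the exact-split behaviour of `dist_utils.divide`, and the sizes are illustrative for Qwen2.5-VL's vision tower:

```python
# Illustrative stand-in for dist_utils.divide: integer division that insists
# on an exact split (use the real vllm.distributed.utils.divide in practice).
def divide(numerator: int, denominator: int) -> int:
    assert numerator % denominator == 0, "heads must split evenly"
    return numerator // denominator


# Example sizes: a 1280-wide projection with 16 attention heads.
num_heads, projection_size = 16, 1280
head_dim = divide(projection_size, num_heads)  # 80, identical for TP and DP

tp_heads = divide(num_heads, 2)  # TP with tp_size=2 -> 8 heads per rank
dp_heads = divide(num_heads, 1)  # DP path forces tp_size=1 -> all 16 heads
```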
@@ -302,8 +327,6 @@ class Qwen2_5_VisionAttention(nn.Module):
k = apply_rotary_pos_emb_vision(k, rotary_pos_emb)
if self.is_flash_attn_backend:
# from vllm_flash_attn.flash_attn_interface import (
# flash_attn_varlen_func)
if self.attn_backend == _Backend.ROCM_AITER_FA:
from aiter import flash_attn_varlen_func
else:
@@ -370,23 +393,27 @@ class Qwen2_5_VisionBlock(nn.Module):
norm_layer: Optional[Callable[[int], nn.Module]] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
use_data_parallel: bool = False,
) -> None:
super().__init__()
if norm_layer is None:
norm_layer = partial(nn.LayerNorm, eps=1e-6)
self.norm1 = norm_layer(dim)
self.norm2 = norm_layer(dim)
self.attn = Qwen2_5_VisionAttention(embed_dim=dim,
num_heads=num_heads,
projection_size=dim,
quant_config=quant_config,
prefix=f"{prefix}.attn")
self.attn = Qwen2_5_VisionAttention(
embed_dim=dim,
num_heads=num_heads,
projection_size=dim,
quant_config=quant_config,
prefix=f"{prefix}.attn",
use_data_parallel=use_data_parallel)
self.mlp = Qwen2_5_VisionMLP(dim,
mlp_hidden_dim,
act_fn=act_fn,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.mlp")
prefix=f"{prefix}.mlp",
use_data_parallel=use_data_parallel)
def forward(
self,
@@ -445,24 +472,30 @@ class Qwen2_5_VisionPatchMerger(nn.Module):
spatial_merge_size: int = 2,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
use_data_parallel: bool = False,
) -> None:
super().__init__()
self.hidden_size = context_dim * (spatial_merge_size**2)
if norm_layer is None:
norm_layer = partial(nn.LayerNorm, eps=1e-6)
self.ln_q = norm_layer(context_dim)
cls_fc1 = (ReplicatedLinear
if use_data_parallel else ColumnParallelLinear)
cls_fc2 = (ReplicatedLinear
if use_data_parallel else RowParallelLinear)
self.mlp = nn.ModuleList([
ColumnParallelLinear(self.hidden_size,
self.hidden_size,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.mlp.0"),
cls_fc1(self.hidden_size,
self.hidden_size,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.mlp.0"),
nn.GELU(),
RowParallelLinear(self.hidden_size,
d_model,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.mlp.2"),
cls_fc2(self.hidden_size,
d_model,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.mlp.2"),
])
def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -514,6 +547,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
norm_eps: float = 1e-6,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
use_data_parallel: bool = False,
) -> None:
super().__init__()
@@ -523,6 +557,8 @@ class Qwen2_5_VisionTransformer(nn.Module):
depth = vision_config.depth
self.hidden_size = vision_config.hidden_size
self.num_heads = vision_config.num_heads
self.use_data_parallel = use_data_parallel
self.out_hidden_size = vision_config.out_hidden_size
# args for get_window_index_thw
self.window_size = vision_config.window_size
@@ -550,7 +586,8 @@ class Qwen2_5_VisionTransformer(nn.Module):
vision_config.hidden_act),
norm_layer=norm_layer,
quant_config=quant_config,
prefix=f"{prefix}.blocks.{layer_idx}")
prefix=f"{prefix}.blocks.{layer_idx}",
use_data_parallel=use_data_parallel)
for layer_idx in range(depth)
])
self.merger = Qwen2_5_VisionPatchMerger(
@@ -560,6 +597,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
spatial_merge_size=self.spatial_merge_size,
quant_config=quant_config,
prefix=f"{prefix}.merger",
use_data_parallel=use_data_parallel,
)
self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
@@ -767,7 +805,6 @@ class Qwen2_5_VisionTransformer(nn.Module):
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
@@ -840,6 +877,8 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
config: Qwen2_5_VLConfig = vllm_config.model_config.hf_config
multimodal_config = vllm_config.model_config.multimodal_config
self.use_data_parallel = (vllm_config.parallel_config.
enable_multimodal_encoder_data_parallel)
self.config = config
self.multimodal_config = multimodal_config
@@ -851,6 +890,7 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
quant_config=self._maybe_ignore_quant_config(
self.quant_config),
prefix=maybe_prefix(prefix, "visual"),
use_data_parallel=self.use_data_parallel,
)
else:
self.visual = None
@@ -973,7 +1013,13 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
image_embeds = image_input["image_embeds"].type(self.visual.dtype)
else:
pixel_values = image_input["pixel_values"]
image_embeds = self.visual(pixel_values, grid_thw=grid_thw_list)
if self.use_data_parallel:
return run_dp_sharded_mrope_vision_model(
self.visual, pixel_values, grid_thw_list)
else:
image_embeds = self.visual(pixel_values,
grid_thw=grid_thw_list)
# Split concatenated embeddings for each image item.
# Using prod on grid_thw_list instead of grid_thw.prod avoids CUDA sync
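`run_dp_sharded_mrope_vision_model` (imported near the top of the diff) is what actually distributes the work: conceptually it assigns images to DP ranks, runs the replicated ViT on each rank's share, and gathers the per-image embeddings back, which is why the DP branch can return directly instead of splitting afterwards. A rough, self-contained sketch of that idea, not the helper's real implementation (the real one also balances load and all-gathers across ranks):

```python
import torch


def dp_shard_images_sketch(vision_model, pixel_values: torch.Tensor,
                           grid_thw_list: list[list[int]],
                           dp_rank: int, dp_world_size: int):
    """Toy round-robin sharding of images across DP ranks (local part only)."""
    # Each image contributes t * h * w patch rows to the flattened pixel_values.
    tokens_per_image = [t * h * w for t, h, w in grid_thw_list]
    offsets = [0]
    for n in tokens_per_image:
        offsets.append(offsets[-1] + n)

    local_embeds = []
    for idx, (start, end) in enumerate(zip(offsets[:-1], offsets[1:])):
        if idx % dp_world_size != dp_rank:
            continue  # another DP rank handles this image
        local_embeds.append(
            vision_model(pixel_values[start:end],
                         grid_thw=[grid_thw_list[idx]]))
    return local_embeds
```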
@@ -995,8 +1041,12 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
video_embeds = video_input["video_embeds"].type(self.visual.dtype)
else:
pixel_values_videos = video_input["pixel_values_videos"]
video_embeds = self.visual(pixel_values_videos,
grid_thw=grid_thw_list)
if self.use_data_parallel:
return run_dp_sharded_mrope_vision_model(
self.visual, pixel_values_videos, grid_thw_list)
else:
video_embeds = self.visual(pixel_values_videos,
grid_thw=grid_thw_list)
# Split concatenated embeddings for each video item.
merge_size = self.visual.spatial_merge_size
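For completeness, the flag the model reads (`parallel_config.enable_multimodal_encoder_data_parallel`, see the constructor hunk above) has to be enabled when the engine is created. A minimal sketch, assuming the field is exposed as an engine argument of the same name; check your vLLM version's EngineArgs for the exact spelling:

```python
from vllm import LLM

# Assumed wiring: the constructor hunk reads
# vllm_config.parallel_config.enable_multimodal_encoder_data_parallel,
# so the engine is started with that option alongside tensor parallelism.
llm = LLM(
    model="Qwen/Qwen2.5-VL-7B-Instruct",
    tensor_parallel_size=4,
    # Assumed keyword name; it mirrors the ParallelConfig field in this diff.
    enable_multimodal_encoder_data_parallel=True,
)
```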