[VLM] Migrate remaining DP-supported ViT models to use disable_tp (#24363)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Author: Isotr0py
Date: 2025-09-12 02:30:41 +08:00 (committed by GitHub)
Parent: 361ae27f8a
Commit: bb2b5126da
3 changed files with 54 additions and 106 deletions
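Every touched module follows the same pattern: instead of branching between ReplicatedLinear and the tensor-parallel linear classes when the vision tower runs data-parallel, the tensor-parallel class is always constructed and sharding is switched off per instance via its disable_tp argument. A minimal before/after sketch of that pattern (sizes and prefix are made-up placeholders, quant_config is omitted, and an initialized vLLM distributed context is assumed):

# Illustrative sketch only, not part of the diff.
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               ReplicatedLinear)

hidden_size, intermediate_size = 1152, 4304      # placeholder ViT sizes
use_data_parallel = True
prefix = "vision_model.encoder.layers.0.mlp"     # placeholder prefix

# Before: pick a different linear class depending on the mode.
cls_fc1 = ReplicatedLinear if use_data_parallel else ColumnParallelLinear
fc1_old = cls_fc1(hidden_size, intermediate_size, bias=True,
                  prefix=f"{prefix}.fc1")

# After: always build the tensor-parallel class; disable_tp=True keeps the
# weight unsharded, which is what the data-parallel ViT path needs.
fc1_new = ColumnParallelLinear(hidden_size, intermediate_size, bias=True,
                               prefix=f"{prefix}.fc1",
                               disable_tp=use_data_parallel)

This removes the per-model ReplicatedLinear special cases and lets one constructor call serve both the TP and DP configurations.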


@@ -31,7 +31,6 @@ from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                QKVParallelLinear,
-                                               ReplicatedLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@@ -139,37 +138,22 @@ class Idefics2VisionAttention(nn.Module):
         assert self.num_heads % tp_size == 0
         self.num_heads_per_partition = self.num_heads // tp_size
-        if use_data_parallel:
-            self.q_size = self.num_heads * self.head_dim
-            self.qkv_proj = ReplicatedLinear(
-                self.embed_dim,
-                3 * self.q_size,
-                bias=True,
-                quant_config=quant_config,
-                prefix=f"{prefix}.qkv_proj",
-            )
-            self.out_proj = ReplicatedLinear(
-                self.embed_dim,
-                self.embed_dim,
-                bias=True,
-                quant_config=quant_config,
-                prefix=f"{prefix}.out_proj",
-            )
-        else:
-            self.qkv_proj = QKVParallelLinear(
-                self.embed_dim,
-                self.head_dim,
-                self.num_heads,
-                quant_config=quant_config,
-                prefix=f"{prefix}.qkv_proj",
-            )
-            self.out_proj = RowParallelLinear(
-                self.embed_dim,
-                self.embed_dim,
-                bias=True,
-                quant_config=quant_config,
-                prefix=f"{prefix}.out_proj",
-            )
+        self.qkv_proj = QKVParallelLinear(
+            self.embed_dim,
+            self.head_dim,
+            self.num_heads,
+            quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
+            disable_tp=use_data_parallel,
+        )
+        self.out_proj = RowParallelLinear(
+            self.embed_dim,
+            self.embed_dim,
+            bias=True,
+            quant_config=quant_config,
+            prefix=f"{prefix}.out_proj",
+            disable_tp=use_data_parallel,
+        )
         # Use unified MultiHeadAttention with Flash Attention support
         self.attn = MultiHeadAttention(self.num_heads_per_partition,
                                        self.head_dim, self.scale)
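The attention hunk above is behavior-preserving: when the vision tower runs data-parallel, tp_size is already taken as 1 earlier in __init__ (unchanged context above this hunk), so num_heads_per_partition == num_heads, and QKVParallelLinear with disable_tp=True holds the full fused QKV weight rather than a per-rank shard. A small worked example with assumed (not from this diff) SigLIP-like sizes:

# Assumed sizes, for illustration only.
embed_dim, num_heads = 1152, 16
head_dim = embed_dim // num_heads                 # 72

tp_size = 1                                       # DP vision tower: TP disabled
num_heads_per_partition = num_heads // tp_size    # 16, i.e. all heads stay local

# Output width of the fused QKV projection when nothing is sharded:
qkv_out = 3 * num_heads_per_partition * head_dim  # 3 * 16 * 72 = 3456
assert qkv_out == 3 * embed_dim                   # matches the old ReplicatedLinear path (3 * q_size)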
@@ -201,23 +185,21 @@ class Idefics2VisionMLP(nn.Module):
         super().__init__()
         self.config = config
         self.activation_fn = get_act_fn(config.hidden_act)
-        cls_fc1 = (ReplicatedLinear
-                   if use_data_parallel else ColumnParallelLinear)
-        self.fc1 = cls_fc1(
+        self.fc1 = ColumnParallelLinear(
             config.hidden_size,
             config.intermediate_size,
             bias=True,
             quant_config=quant_config,
             prefix=f"{prefix}.fc1",
+            disable_tp=use_data_parallel,
         )
-        cls_fc2 = (ReplicatedLinear
-                   if use_data_parallel else RowParallelLinear)
-        self.fc2 = cls_fc2(
+        self.fc2 = RowParallelLinear(
             config.intermediate_size,
             config.hidden_size,
             bias=True,
             quant_config=quant_config,
             prefix=f"{prefix}.fc2",
+            disable_tp=use_data_parallel,
         )

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@@ -389,30 +371,6 @@ class Idefics2VisionTransformer(nn.Module):
         last_hidden_state = self.post_layernorm(encoder_outputs)
         return last_hidden_state

-    def _consolidate_qkv_weights(
-        self, weights: Iterable[tuple[str, torch.Tensor]]
-    ) -> Iterable[tuple[str, torch.Tensor]]:
-        qkv_idx_mappings = {
-            ".self_attn.q_proj": 0,
-            ".self_attn.k_proj": 1,
-            ".self_attn.v_proj": 2,
-        }
-        qkv_weights = {}
-        for name, loaded_weight in weights:
-            for weight_name, idx in qkv_idx_mappings.items():
-                if weight_name not in name:
-                    continue
-                new_name = name.replace(weight_name, ".self_attn.qkv_proj")
-                if new_name not in qkv_weights:
-                    qkv_weights[new_name] = [None] * 3
-                qkv_weights[new_name][idx] = loaded_weight
-                break
-            else:
-                yield name, loaded_weight
-        for key, weight in qkv_weights.items():
-            qkv_weight = torch.cat(weight, dim=0)
-            yield key, qkv_weight
-
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
         stacked_params_mapping = [
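Dropping _consolidate_qkv_weights is possible because the DP path no longer uses ReplicatedLinear for the fused projection: QKVParallelLinear's weight loader already accepts the separate q_proj/k_proj/v_proj checkpoint tensors, with load_weights routing them through stacked_params_mapping and a shard id, and that works whether the weight is TP-sharded or kept whole via disable_tp. A hedged sketch of the renaming step (the real mapping is the one defined in the unchanged load_weights shown below):

# Sketch of the name remapping that makes the removed helper redundant.
from typing import Optional

stacked_params_mapping = [
    # (fused_param_name, checkpoint_shard_name, shard_id)
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
]

def remap(name: str) -> tuple[str, Optional[str]]:
    """Map a q/k/v checkpoint name onto the fused qkv_proj parameter."""
    for param_name, shard_name, shard_id in stacked_params_mapping:
        if shard_name in name:
            return name.replace(shard_name, param_name), shard_id
    return name, None

# e.g. remap("...self_attn.q_proj.weight")
#   -> ("...self_attn.qkv_proj.weight", "q")
# QKVParallelLinear's weight loader then copies the tensor into the matching
# rows of the fused weight, so no manual torch.cat is needed for DP.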
@@ -425,9 +383,6 @@ class Idefics2VisionTransformer(nn.Module):
         loaded_params: set[str] = set()
         layer_count = len(self.encoder.layers)

-        if self.use_data_parallel:
-            weights = self._consolidate_qkv_weights(weights)
-
         for name, loaded_weight in weights:
             # skip pooling header
             if name.startswith("head."):