[VLM] Migrate remain DP-supported ViT models to use disable_tp (#24363)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
@@ -31,7 +31,6 @@ from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
ReplicatedLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
@@ -139,37 +138,22 @@ class Idefics2VisionAttention(nn.Module):
|
||||
assert self.num_heads % tp_size == 0
|
||||
self.num_heads_per_partition = self.num_heads // tp_size
|
||||
|
||||
if use_data_parallel:
|
||||
self.q_size = self.num_heads * self.head_dim
|
||||
self.qkv_proj = ReplicatedLinear(
|
||||
self.embed_dim,
|
||||
3 * self.q_size,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.qkv_proj",
|
||||
)
|
||||
self.out_proj = ReplicatedLinear(
|
||||
self.embed_dim,
|
||||
self.embed_dim,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.out_proj",
|
||||
)
|
||||
else:
|
||||
self.qkv_proj = QKVParallelLinear(
|
||||
self.embed_dim,
|
||||
self.head_dim,
|
||||
self.num_heads,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.qkv_proj",
|
||||
)
|
||||
self.out_proj = RowParallelLinear(
|
||||
self.embed_dim,
|
||||
self.embed_dim,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.out_proj",
|
||||
)
|
||||
self.qkv_proj = QKVParallelLinear(
|
||||
self.embed_dim,
|
||||
self.head_dim,
|
||||
self.num_heads,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.qkv_proj",
|
||||
disable_tp=use_data_parallel,
|
||||
)
|
||||
self.out_proj = RowParallelLinear(
|
||||
self.embed_dim,
|
||||
self.embed_dim,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.out_proj",
|
||||
disable_tp=use_data_parallel,
|
||||
)
|
||||
# Use unified MultiHeadAttention with Flash Attention support
|
||||
self.attn = MultiHeadAttention(self.num_heads_per_partition,
|
||||
self.head_dim, self.scale)
|
||||
@@ -201,23 +185,21 @@ class Idefics2VisionMLP(nn.Module):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.activation_fn = get_act_fn(config.hidden_act)
|
||||
cls_fc1 = (ReplicatedLinear
|
||||
if use_data_parallel else ColumnParallelLinear)
|
||||
self.fc1 = cls_fc1(
|
||||
self.fc1 = ColumnParallelLinear(
|
||||
config.hidden_size,
|
||||
config.intermediate_size,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.fc1",
|
||||
disable_tp=use_data_parallel,
|
||||
)
|
||||
cls_fc2 = (ReplicatedLinear
|
||||
if use_data_parallel else RowParallelLinear)
|
||||
self.fc2 = cls_fc2(
|
||||
self.fc2 = RowParallelLinear(
|
||||
config.intermediate_size,
|
||||
config.hidden_size,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.fc2",
|
||||
disable_tp=use_data_parallel,
|
||||
)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
@@ -389,30 +371,6 @@ class Idefics2VisionTransformer(nn.Module):
|
||||
last_hidden_state = self.post_layernorm(encoder_outputs)
|
||||
return last_hidden_state
|
||||
|
||||
def _consolidate_qkv_weights(
|
||||
self, weights: Iterable[tuple[str, torch.Tensor]]
|
||||
) -> Iterable[tuple[str, torch.Tensor]]:
|
||||
qkv_idx_mappings = {
|
||||
".self_attn.q_proj": 0,
|
||||
".self_attn.k_proj": 1,
|
||||
".self_attn.v_proj": 2,
|
||||
}
|
||||
qkv_weights = {}
|
||||
for name, loaded_weight in weights:
|
||||
for weight_name, idx in qkv_idx_mappings.items():
|
||||
if weight_name not in name:
|
||||
continue
|
||||
new_name = name.replace(weight_name, ".self_attn.qkv_proj")
|
||||
if new_name not in qkv_weights:
|
||||
qkv_weights[new_name] = [None] * 3
|
||||
qkv_weights[new_name][idx] = loaded_weight
|
||||
break
|
||||
else:
|
||||
yield name, loaded_weight
|
||||
for key, weight in qkv_weights.items():
|
||||
qkv_weight = torch.cat(weight, dim=0)
|
||||
yield key, qkv_weight
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> set[str]:
|
||||
stacked_params_mapping = [
|
||||
@@ -425,9 +383,6 @@ class Idefics2VisionTransformer(nn.Module):
|
||||
loaded_params: set[str] = set()
|
||||
layer_count = len(self.encoder.layers)
|
||||
|
||||
if self.use_data_parallel:
|
||||
weights = self._consolidate_qkv_weights(weights)
|
||||
|
||||
for name, loaded_weight in weights:
|
||||
# skip pooling header
|
||||
if name.startswith("head."):
|
||||
|
||||
Reference in New Issue
Block a user