[VLM] Add TP support for Phi-4-MM (#14453)

Signed-off-by: Isotr0py <2037008807@qq.com>
Author: Isotr0py
Date: 2025-03-08 21:57:14 +08:00
Committed by: GitHub
Parent: cb8bdfade2
Commit: 03fe18ae0f
4 changed files with 50 additions and 295 deletions
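For reference, a minimal sketch of what tensor-parallel serving of Phi-4-MM looks like from the user side once this support lands. The model id, TP degree, and prompt below are illustrative assumptions, not part of this commit; the CLI equivalent would be `vllm serve <model> --tensor-parallel-size 2`.

```python
# Illustrative only: model id and tensor_parallel_size are assumptions,
# not taken from this commit.
from vllm import LLM

llm = LLM(
    model="microsoft/Phi-4-multimodal-instruct",  # assumed HF model id
    trust_remote_code=True,
    tensor_parallel_size=2,  # shard attention/MLP weights across 2 GPUs
)
outputs = llm.generate("Describe the attached audio clip.")
print(outputs[0].outputs[0].text)
```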


@@ -6,69 +6,26 @@
#!/usr/bin/env python3
import abc
import math
from functools import partial
from typing import Callable, Dict, List, Literal, Optional, Union
from typing import List, Literal, Optional
import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor, nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
CheckpointImpl, CheckpointWrapper, checkpoint_wrapper, offload_wrapper)
CheckpointWrapper)
from torch.distributed.fsdp.fully_sharded_data_parallel import (
FullyShardedDataParallel)
from torch.utils.checkpoint import checkpoint
from transformers import PretrainedConfig
from vllm.model_executor.models.phi4mm_utils import (
AbsolutePositionalEncoding, ConvModule, FeedForward, MeanVarianceNormLayer,
MultiHeadedAttention, NemoConvSubsampling, T5RelativeAttentionLogitBias,
adaptive_enc_mask, attn_checkpointing, embedding_checkpoint_wrapper,
get_offset, repeat, unfold_tensor, validate_checkpointing_config)
MultiHeadedAttention, MultiSequential, NemoConvSubsampling,
T5RelativeAttentionLogitBias, adaptive_enc_mask, get_offset, unfold_tensor)
_AUDIO_PLACEHOLDER_TOKEN_ID = 200011 # <|endoftext11|>
def encoder_checkpoint_wrapper(
activation_checkpointing: Union[str, Dict],
layer_cls: type,
idx: int = 0,
) -> Callable:
"""return encoder activation checkpoint wrapper"""
validate_checkpointing_config(activation_checkpointing)
if isinstance(activation_checkpointing, str):
if activation_checkpointing:
if activation_checkpointing == "offload":
return offload_wrapper
return partial(checkpoint_wrapper)
return lambda x: x
if isinstance(activation_checkpointing, dict):
target_layer_cls = activation_checkpointing.get(
"module", "transformer")
if target_layer_cls.lower() == "transformer":
target_layer_cls = (
"EncoderLayer",
"ConformerEncoderLayer",
)
elif target_layer_cls.lower() == "attention":
target_layer_cls = ("MultiHeadedAttention", "MultiHeadAttention")
checkpointing_interval = activation_checkpointing.get("interval", 1)
offloading = activation_checkpointing.get("offload", False)
impl = (CheckpointImpl.REENTRANT if activation_checkpointing.get(
"reentrant", True) else CheckpointImpl.NO_REENTRANT)
if (idx % checkpointing_interval == 0
and layer_cls.__name__ in target_layer_cls):
if offloading:
return offload_wrapper
return partial(checkpoint_wrapper, checkpoint_impl=impl)
return lambda x: x
raise ValueError("Invalid activation_checkpointing config")
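For context, a minimal sketch of the per-layer activation-checkpointing pattern the helper above dispatches to, using the public torch `checkpoint_wrapper`/`offload_wrapper` API already imported in this file; the `wrap_layers` helper and its arguments are stand-ins, not code from the repository.

```python
# Sketch of interval-based activation checkpointing over a stack of layers.
# Assumes generic nn.Module layers; not the actual Phi-4-MM implementation.
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl, checkpoint_wrapper, offload_wrapper)

def wrap_layers(layers: nn.ModuleList, interval: int = 1,
                offload: bool = False) -> nn.ModuleList:
    wrapped = []
    for i, layer in enumerate(layers):
        if i % interval == 0:
            # Either offload activations to CPU or recompute them in backward.
            layer = (offload_wrapper(layer) if offload else
                     checkpoint_wrapper(
                         layer, checkpoint_impl=CheckpointImpl.NO_REENTRANT))
        wrapped.append(layer)
    return nn.ModuleList(wrapped)
```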
class ConformerEncoderLayer(nn.Module):
"""ConformerEncoder Layer module.
for more details see conformer paper:
@@ -208,10 +165,7 @@ class ConformerEncoderLayer(nn.Module):
bias_in_glu=bias_in_glu,
)
self.self_attn = encoder_checkpoint_wrapper(
activation_checkpointing,
MultiHeadedAttention,
)(MultiHeadedAttention(
self.self_attn = MultiHeadedAttention(
n_head,
d_model,
dropout_rate,
@@ -221,7 +175,7 @@ class ConformerEncoderLayer(nn.Module):
use_pt_scaled_dot_product_attention=
use_pt_scaled_dot_product_attention,
group_size=attn_group_sizes,
))
)
self.conv = ConvModule(
d_model,
ext_pw_out_channel,
@@ -441,26 +395,8 @@ class TransformerEncoderBase(abc.ABC, nn.Module):
else:
raise NotImplementedError
def post_init(self, init_model_config):
pretrained_speech_encoder_path = init_model_config.get(
"pretrained_speech_encoder_path", None)
if pretrained_speech_encoder_path:
model_state = torch.load(pretrained_speech_encoder_path,
map_location="cpu")
encoder_state_dict = {}
for k, v in model_state.items():
if "encoder." in k:
tmp_k = k.replace("encoder.", "")
encoder_state_dict[tmp_k] = v
if hasattr(self, "encoder_embedding"):
del self.encoder_embedding
self.load_state_dict(encoder_state_dict)
if not hasattr(self, "encoder_embedding"):
self.encoder_embedding = MeanVarianceNormLayer(
self.encoder_embedding_config["input_size"])
self.encoder_embedding = MeanVarianceNormLayer(
self.encoder_embedding_config["input_size"])
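`MeanVarianceNormLayer` is defined in `phi4mm_utils` and its internals are not shown in this diff; a rough sketch of what such a feature-normalization front-end typically looks like, with buffer names that are assumptions, is:

```python
import torch
import torch.nn as nn

class MeanVarianceNorm(nn.Module):
    """Normalize each feature dimension with precomputed global statistics.
    Stand-in for phi4mm_utils.MeanVarianceNormLayer; names are assumptions."""

    def __init__(self, input_size: int):
        super().__init__()
        self.register_buffer("global_mean", torch.zeros(input_size))
        self.register_buffer("global_invstd", torch.ones(input_size))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, time, input_size) log-mel features
        return (x - self.global_mean) * self.global_invstd
```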
def compute_lens_change(self, feature_lens):
"""feature_lens: int
@@ -558,14 +494,6 @@ class TransformerEncoderBase(abc.ABC, nn.Module):
# Create mask matrix for streaming
# S stores start index. if chunksize is 18, s is [0,18,36,....]
chunk_start_idx = np.arange(0, seq_len, chunk_size_train_eff)
# avoid randomness when run evaluation or decoding
if self.training and np.random.rand() > 0.5:
# Either first or last chunk is not complete.
# If only the last one is not complete, EOS is not effective
chunk_start_idx = seq_len - chunk_start_idx
chunk_start_idx = chunk_start_idx[::-1]
chunk_start_idx = chunk_start_idx[:-1]
chunk_start_idx = np.insert(chunk_start_idx, 0, 0)
enc_streaming_mask = (adaptive_enc_mask(
seq_len, chunk_start_idx,
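`adaptive_enc_mask` lives in `phi4mm_utils` and is not shown here; a plausible sketch of the chunk-based streaming mask it produces (each frame may attend within its own chunk plus a limited number of chunks to its left) is:

```python
import numpy as np
import torch

def chunk_streaming_mask(seq_len: int, chunk_start_idx: np.ndarray,
                         left_window: int = 0) -> torch.Tensor:
    """Boolean (seq_len, seq_len) mask: True where attention is allowed.
    Sketch of the behaviour attributed to adaptive_enc_mask, not its source."""
    # Map every frame to the chunk it belongs to, e.g. starts [0, 18, 36, ...].
    chunk_id = np.searchsorted(chunk_start_idx, np.arange(seq_len),
                               side="right") - 1
    q = torch.as_tensor(chunk_id)[:, None]  # query frame's chunk
    k = torch.as_tensor(chunk_id)[None, :]  # key frame's chunk
    # Attend to the current chunk and up to `left_window` chunks on the left.
    return (k <= q) & (k >= q - left_window)
```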
@@ -883,23 +811,17 @@ class ConformerEncoder(TransformerEncoderBase):
self.num_blocks = num_blocks
self.num_lang = num_lang
self.kernel_size = kernel_size
self.embed = embedding_checkpoint_wrapper(activation_checkpointing)(
self.embed)
self.replication_pad_for_subsample_embedding: bool = (
replication_pad_for_subsample_embedding)
assert (self.num_heads % attention_group_size == 0
), "attention_group_size must divide n_head"
self.num_heads_k = self.num_heads // attention_group_size
self.encoders = repeat(
num_blocks,
lambda i: encoder_checkpoint_wrapper(activation_checkpointing,
ConformerEncoderLayer, i)
(ConformerEncoderLayer(
self.encoders = MultiSequential(*[
ConformerEncoderLayer(
d_model=attention_dim,
ext_pw_out_channel=ext_pw_out_channel,
depthwise_seperable_out_channel=
depthwise_seperable_out_channel,
depthwise_seperable_out_channel=depthwise_seperable_out_channel,
depthwise_multiplier=depthwise_multiplier,
n_head=attention_heads,
d_ffn=linear_units,
@@ -916,14 +838,13 @@ class ConformerEncoder(TransformerEncoderBase):
bias_in_glu=bias_in_glu,
linear_glu_in_convm=linear_glu_in_convm,
attention_glu_type=attention_glu_type,
activation_checkpointing=attn_checkpointing(
activation_checkpointing, i),
activation_checkpointing=activation_checkpointing,
export=export,
use_pt_scaled_dot_product_attention=
use_pt_scaled_dot_product_attention,
attn_group_sizes=attention_group_size,
)),
)
) for _ in range(num_blocks)
])
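`MultiSequential` also comes from `phi4mm_utils` and its code is not in this diff; it is presumably an `nn.Sequential` variant that threads several values (features plus mask) through each layer, roughly:

```python
import torch.nn as nn

class MultiSequentialSketch(nn.Module):
    """Sequential container whose layers take and return several values.
    Assumed behaviour of phi4mm_utils.MultiSequential, not its source."""

    def __init__(self, *layers: nn.Module):
        super().__init__()
        self.layers = nn.ModuleList(layers)

    def forward(self, *args):
        for layer in self.layers:
            args = layer(*args)  # each layer returns the next (x, mask, ...)
        return args
```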
self.extra_layer_output_idx = extra_layer_output_idx
self.extra_multi_layer_output_idxs = extra_multi_layer_output_idxs
# Make a zeros scalar we can use in get_initial_state to determine
@@ -1041,9 +962,6 @@ class ConformerEncoder(TransformerEncoderBase):
return input_tensor, masks # , layer_emb
def gradient_checkpointing_enable(self):
pass
class WindowQformer(nn.Module):
"""Window-level Qformer"""
@@ -1077,13 +995,6 @@ class WindowQformer(nn.Module):
self.after_norm = (nn.LayerNorm(attention_dim, eps=1e-12)
if normalize_before else None)
self.window_size = window_size
self.gradient_checkpointing_enable = False
def enable_gradient_checkpointing(self):
self.gradient_checkpointing_enable = True
def disable_gradient_checkpointing(self):
self.gradient_checkpointing_enable = False
def forward(self, audio_embed, mask, embed_len=None):
"""forward decoder"""
@@ -1111,20 +1022,10 @@ class WindowQformer(nn.Module):
# NT' x 1 x D
q = self.queries.expand(bsz * slen, -1, -1)
for layer in self.decoders:
if self.gradient_checkpointing_enable and self.training:
q = checkpoint(
layer.__call__,
q,
embed_chunk,
None,
mask,
use_reentrant=True,
)
else:
q = layer(tgt=q,
memory=embed_chunk,
tgt_mask=None,
memory_mask=mask)
q = layer(tgt=q,
memory=embed_chunk,
tgt_mask=None,
memory_mask=mask)
if self.after_norm is not None:
q = self.after_norm(q)
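The windowing step that produces `embed_chunk` sits outside this hunk; a simplified sketch of the window-level Q-Former idea (pool every `window_size` frames into one learned query via cross-attention), under assumptions about shapes and a single batch-first decoder layer, is:

```python
import torch
import torch.nn as nn

def window_qformer_pool(audio_embed: torch.Tensor, queries: torch.Tensor,
                        decoder: nn.TransformerDecoderLayer,  # batch_first=True
                        window_size: int) -> torch.Tensor:
    """audio_embed: (B, T, D), queries: (1, 1, D) -> (B, T // window_size, D).
    Sketch only; the real WindowQformer stacks several decoder layers and
    handles padding and masks."""
    bsz, t, d = audio_embed.shape
    slen = t // window_size
    # Group frames into non-overlapping windows: (B * slen, window_size, D)
    embed_chunk = audio_embed[:, :slen * window_size].reshape(
        bsz * slen, window_size, d)
    # One learned query cross-attends to each window.
    q = queries.expand(bsz * slen, -1, -1)
    q = decoder(tgt=q, memory=embed_chunk)
    return q.reshape(bsz, slen, d)
```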
@@ -1147,13 +1048,6 @@ class AudioEmbedding(nn.Module):
hidden_size = (config.n_embd
if hasattr(config, "n_embd") else config.hidden_size)
if hasattr(config, "embd_pdrop") or hasattr(config, "embed_pdrop"):
embd_drop = (config.embd_pdrop if hasattr(config, "embd_pdrop")
else config.embed_pdrop)
self.drop = nn.Dropout(embd_drop)
else:
self.drop = None
# self.wte = nn.Embedding(config.vocab_size, hidden_size)
audio_dim_out = (
@@ -1167,12 +1061,6 @@ class AudioEmbedding(nn.Module):
assert encoder_config is not None
self.encoder = ConformerEncoder(**encoder_config)
# fake initialization, create encoder_embedding layer only so that
# in decoding, all parameters can be loaded in
# from_pretrained_function in training, we do post init after
# from_pretrained function to make sure the correct initialization
self.encoder.post_init({})
audio_dim_out = encoder_config["attention_dim"]
n_mels = encoder_config["input_size"]
else:
@@ -1221,14 +1109,6 @@ class AudioEmbedding(nn.Module):
else:
self.conv_ds = None
enable_gradient_checkpointing = kwargs.get(
"enable_gradient_checkpointing", False)
if enable_gradient_checkpointing:
self.encoder.gradient_checkpointing_enable()
if self.qformer:
self.qformer.enable_gradient_checkpointing()
projection_cls = kwargs.get("projection_cls", "linear")
if projection_cls == "linear":
self.audio_projection = nn.Linear(audio_dim_out, hidden_size)
@@ -1388,16 +1268,4 @@ class AudioEmbedding(nn.Module):
hidden_states.dtype).to(hidden_states.device))
idx += cnt
else:
if self.training:
# hidden_states[:, 0:img_set_tensor.shape[0]] =
# hidden_states[:, 0:img_set_tensor.shape[0]] +
# 0 * img_set_tensor.to(hidden_states.dtype)
# .to(hidden_states.device)
hidden_states[:, 0:1] = hidden_states[:, 0:1] + \
0 * audio_set_tensor[:, 0:1].to(hidden_states.dtype)\
.to(hidden_states.device)
if self.drop is not None:
hidden_states = self.drop(hidden_states)
return hidden_states
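Not shown in this hunk is how the projected audio features end up at the `<|endoftext11|>` positions of the sequence; a generic sketch of that merge step (the helper name and shapes are illustrative, not the file's actual code) is:

```python
import torch

_AUDIO_PLACEHOLDER_TOKEN_ID = 200011  # <|endoftext11|>

def merge_audio_embeddings(input_ids: torch.Tensor,
                           hidden_states: torch.Tensor,
                           audio_features: torch.Tensor) -> torch.Tensor:
    """Place one audio feature vector at every audio placeholder token.
    input_ids: (N,), hidden_states: (N, D), audio_features: (M, D) with
    M == number of placeholder tokens. Illustrative sketch only."""
    mask = input_ids == _AUDIO_PLACEHOLDER_TOKEN_ID
    hidden_states = hidden_states.clone()
    hidden_states[mask] = audio_features.to(hidden_states.dtype)
    return hidden_states
```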