[VLM] Add TP support for Phi-4-MM (#14453)
Signed-off-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
@@ -6,69 +6,26 @@
|
||||
#!/usr/bin/env python3
|
||||
import abc
|
||||
import math
|
||||
from functools import partial
|
||||
from typing import Callable, Dict, List, Literal, Optional, Union
|
||||
from typing import List, Literal, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor, nn
|
||||
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
|
||||
CheckpointImpl, CheckpointWrapper, checkpoint_wrapper, offload_wrapper)
|
||||
CheckpointWrapper)
|
||||
from torch.distributed.fsdp.fully_sharded_data_parallel import (
|
||||
FullyShardedDataParallel)
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.model_executor.models.phi4mm_utils import (
|
||||
AbsolutePositionalEncoding, ConvModule, FeedForward, MeanVarianceNormLayer,
|
||||
MultiHeadedAttention, NemoConvSubsampling, T5RelativeAttentionLogitBias,
|
||||
adaptive_enc_mask, attn_checkpointing, embedding_checkpoint_wrapper,
|
||||
get_offset, repeat, unfold_tensor, validate_checkpointing_config)
|
||||
MultiHeadedAttention, MultiSequential, NemoConvSubsampling,
|
||||
T5RelativeAttentionLogitBias, adaptive_enc_mask, get_offset, unfold_tensor)
|
||||
|
||||
_AUDIO_PLACEHOLDER_TOKEN_ID = 200011 # <|endoftext11|>
|
||||
|
||||
|
||||
def encoder_checkpoint_wrapper(
|
||||
activation_checkpointing: Union[str, Dict],
|
||||
layer_cls: type,
|
||||
idx: int = 0,
|
||||
) -> Callable:
|
||||
"""return encoder activation checkpoint wrapper"""
|
||||
validate_checkpointing_config(activation_checkpointing)
|
||||
|
||||
if isinstance(activation_checkpointing, str):
|
||||
if activation_checkpointing:
|
||||
if activation_checkpointing == "offload":
|
||||
return offload_wrapper
|
||||
return partial(checkpoint_wrapper)
|
||||
return lambda x: x
|
||||
|
||||
if isinstance(activation_checkpointing, dict):
|
||||
target_layer_cls = activation_checkpointing.get(
|
||||
"module", "transformer")
|
||||
if target_layer_cls.lower() == "transformer":
|
||||
target_layer_cls = (
|
||||
"EncoderLayer",
|
||||
"ConformerEncoderLayer",
|
||||
)
|
||||
elif target_layer_cls.lower() == "attention":
|
||||
target_layer_cls = ("MultiHeadedAttention", "MultiHeadAttention")
|
||||
checkpointing_interval = activation_checkpointing.get("interval", 1)
|
||||
offloading = activation_checkpointing.get("offload", False)
|
||||
impl = (CheckpointImpl.REENTRANT if activation_checkpointing.get(
|
||||
"reentrant", True) else CheckpointImpl.NO_REENTRANT)
|
||||
|
||||
if (idx % checkpointing_interval == 0
|
||||
and layer_cls.__name__ in target_layer_cls):
|
||||
if offloading:
|
||||
return offload_wrapper
|
||||
return partial(checkpoint_wrapper, checkpoint_impl=impl)
|
||||
return lambda x: x
|
||||
|
||||
raise ValueError("Invalid activation_checkpointing config")
|
||||
|
||||
|
||||
class ConformerEncoderLayer(nn.Module):
|
||||
"""ConformerEncoder Layer module.
|
||||
for more details see conformer paper:
|
||||
@@ -208,10 +165,7 @@ class ConformerEncoderLayer(nn.Module):
|
||||
bias_in_glu=bias_in_glu,
|
||||
)
|
||||
|
||||
self.self_attn = encoder_checkpoint_wrapper(
|
||||
activation_checkpointing,
|
||||
MultiHeadedAttention,
|
||||
)(MultiHeadedAttention(
|
||||
self.self_attn = MultiHeadedAttention(
|
||||
n_head,
|
||||
d_model,
|
||||
dropout_rate,
|
||||
@@ -221,7 +175,7 @@ class ConformerEncoderLayer(nn.Module):
|
||||
use_pt_scaled_dot_product_attention=
|
||||
use_pt_scaled_dot_product_attention,
|
||||
group_size=attn_group_sizes,
|
||||
))
|
||||
)
|
||||
self.conv = ConvModule(
|
||||
d_model,
|
||||
ext_pw_out_channel,
|
||||
@@ -441,26 +395,8 @@ class TransformerEncoderBase(abc.ABC, nn.Module):
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
def post_init(self, init_model_config):
|
||||
|
||||
pretrained_speech_encoder_path = init_model_config.get(
|
||||
"pretrained_speech_encoder_path", None)
|
||||
if pretrained_speech_encoder_path:
|
||||
model_state = torch.load(pretrained_speech_encoder_path,
|
||||
map_location="cpu")
|
||||
encoder_state_dict = {}
|
||||
for k, v in model_state.items():
|
||||
if "encoder." in k:
|
||||
tmp_k = k.replace("encoder.", "")
|
||||
encoder_state_dict[tmp_k] = v
|
||||
|
||||
if hasattr(self, "encoder_embedding"):
|
||||
del self.encoder_embedding
|
||||
self.load_state_dict(encoder_state_dict)
|
||||
|
||||
if not hasattr(self, "encoder_embedding"):
|
||||
self.encoder_embedding = MeanVarianceNormLayer(
|
||||
self.encoder_embedding_config["input_size"])
|
||||
self.encoder_embedding = MeanVarianceNormLayer(
|
||||
self.encoder_embedding_config["input_size"])
|
||||
|
||||
def compute_lens_change(self, feature_lens):
|
||||
"""feature_lens: int
|
||||
@@ -558,14 +494,6 @@ class TransformerEncoderBase(abc.ABC, nn.Module):
|
||||
# Create mask matrix for streaming
|
||||
# S stores start index. if chunksize is 18, s is [0,18,36,....]
|
||||
chunk_start_idx = np.arange(0, seq_len, chunk_size_train_eff)
|
||||
# avoid randomness when run evaluation or decoding
|
||||
if self.training and np.random.rand() > 0.5:
|
||||
# Either first or last chunk is not complete.
|
||||
# If only the last one is not complete, EOS is not effective
|
||||
chunk_start_idx = seq_len - chunk_start_idx
|
||||
chunk_start_idx = chunk_start_idx[::-1]
|
||||
chunk_start_idx = chunk_start_idx[:-1]
|
||||
chunk_start_idx = np.insert(chunk_start_idx, 0, 0)
|
||||
|
||||
enc_streaming_mask = (adaptive_enc_mask(
|
||||
seq_len, chunk_start_idx,
|
||||
@@ -883,23 +811,17 @@ class ConformerEncoder(TransformerEncoderBase):
|
||||
self.num_blocks = num_blocks
|
||||
self.num_lang = num_lang
|
||||
self.kernel_size = kernel_size
|
||||
self.embed = embedding_checkpoint_wrapper(activation_checkpointing)(
|
||||
self.embed)
|
||||
self.replication_pad_for_subsample_embedding: bool = (
|
||||
replication_pad_for_subsample_embedding)
|
||||
assert (self.num_heads % attention_group_size == 0
|
||||
), "attention_group_size must divide n_head"
|
||||
self.num_heads_k = self.num_heads // attention_group_size
|
||||
|
||||
self.encoders = repeat(
|
||||
num_blocks,
|
||||
lambda i: encoder_checkpoint_wrapper(activation_checkpointing,
|
||||
ConformerEncoderLayer, i)
|
||||
(ConformerEncoderLayer(
|
||||
self.encoders = MultiSequential(*[
|
||||
ConformerEncoderLayer(
|
||||
d_model=attention_dim,
|
||||
ext_pw_out_channel=ext_pw_out_channel,
|
||||
depthwise_seperable_out_channel=
|
||||
depthwise_seperable_out_channel,
|
||||
depthwise_seperable_out_channel=depthwise_seperable_out_channel,
|
||||
depthwise_multiplier=depthwise_multiplier,
|
||||
n_head=attention_heads,
|
||||
d_ffn=linear_units,
|
||||
@@ -916,14 +838,13 @@ class ConformerEncoder(TransformerEncoderBase):
|
||||
bias_in_glu=bias_in_glu,
|
||||
linear_glu_in_convm=linear_glu_in_convm,
|
||||
attention_glu_type=attention_glu_type,
|
||||
activation_checkpointing=attn_checkpointing(
|
||||
activation_checkpointing, i),
|
||||
activation_checkpointing=activation_checkpointing,
|
||||
export=export,
|
||||
use_pt_scaled_dot_product_attention=
|
||||
use_pt_scaled_dot_product_attention,
|
||||
attn_group_sizes=attention_group_size,
|
||||
)),
|
||||
)
|
||||
) for _ in range(num_blocks)
|
||||
])
|
||||
self.extra_layer_output_idx = extra_layer_output_idx
|
||||
self.extra_multi_layer_output_idxs = extra_multi_layer_output_idxs
|
||||
# Make a zeros scalar we can use in get_initial_state to determine
|
||||
@@ -1041,9 +962,6 @@ class ConformerEncoder(TransformerEncoderBase):
|
||||
|
||||
return input_tensor, masks # , layer_emb
|
||||
|
||||
def gradient_checkpointing_enable(self):
|
||||
pass
|
||||
|
||||
|
||||
class WindowQformer(nn.Module):
|
||||
"""Window-level Qformer"""
|
||||
@@ -1077,13 +995,6 @@ class WindowQformer(nn.Module):
|
||||
self.after_norm = (nn.LayerNorm(attention_dim, eps=1e-12)
|
||||
if normalize_before else None)
|
||||
self.window_size = window_size
|
||||
self.gradient_checkpointing_enable = False
|
||||
|
||||
def enable_gradient_checkpointing(self):
|
||||
self.gradient_checkpointing_enable = True
|
||||
|
||||
def disable_gradient_checkpointing(self):
|
||||
self.gradient_checkpointing_enable = False
|
||||
|
||||
def forward(self, audio_embed, mask, embed_len=None):
|
||||
"""forward decoder"""
|
||||
@@ -1111,20 +1022,10 @@ class WindowQformer(nn.Module):
|
||||
# NT' x 1 x D
|
||||
q = self.queries.expand(bsz * slen, -1, -1)
|
||||
for layer in self.decoders:
|
||||
if self.gradient_checkpointing_enable and self.training:
|
||||
q = checkpoint(
|
||||
layer.__call__,
|
||||
q,
|
||||
embed_chunk,
|
||||
None,
|
||||
mask,
|
||||
use_reentrant=True,
|
||||
)
|
||||
else:
|
||||
q = layer(tgt=q,
|
||||
memory=embed_chunk,
|
||||
tgt_mask=None,
|
||||
memory_mask=mask)
|
||||
q = layer(tgt=q,
|
||||
memory=embed_chunk,
|
||||
tgt_mask=None,
|
||||
memory_mask=mask)
|
||||
|
||||
if self.after_norm is not None:
|
||||
q = self.after_norm(q)
|
||||
@@ -1147,13 +1048,6 @@ class AudioEmbedding(nn.Module):
|
||||
hidden_size = (config.n_embd
|
||||
if hasattr(config, "n_embd") else config.hidden_size)
|
||||
|
||||
if hasattr(config, "embd_pdrop") or hasattr(config, "embed_pdrop"):
|
||||
embd_drop = (config.embd_pdrop if hasattr(config, "embd_pdrop")
|
||||
else config.embed_pdrop)
|
||||
self.drop = nn.Dropout(embd_drop)
|
||||
else:
|
||||
self.drop = None
|
||||
|
||||
# self.wte = nn.Embedding(config.vocab_size, hidden_size)
|
||||
|
||||
audio_dim_out = (
|
||||
@@ -1167,12 +1061,6 @@ class AudioEmbedding(nn.Module):
|
||||
assert encoder_config is not None
|
||||
self.encoder = ConformerEncoder(**encoder_config)
|
||||
|
||||
# fake initialization, create encoder_embedding layer only so that
|
||||
# in decoding, all parameters can be loaded in
|
||||
# from_pretrained_function in training, we do post init after
|
||||
# from_pretrained function to make sure the correct initialization
|
||||
self.encoder.post_init({})
|
||||
|
||||
audio_dim_out = encoder_config["attention_dim"]
|
||||
n_mels = encoder_config["input_size"]
|
||||
else:
|
||||
@@ -1221,14 +1109,6 @@ class AudioEmbedding(nn.Module):
|
||||
else:
|
||||
self.conv_ds = None
|
||||
|
||||
enable_gradient_checkpointing = kwargs.get(
|
||||
"enable_gradient_checkpointing", False)
|
||||
if enable_gradient_checkpointing:
|
||||
self.encoder.gradient_checkpointing_enable()
|
||||
|
||||
if self.qformer:
|
||||
self.qformer.enable_gradient_checkpointing()
|
||||
|
||||
projection_cls = kwargs.get("projection_cls", "linear")
|
||||
if projection_cls == "linear":
|
||||
self.audio_projection = nn.Linear(audio_dim_out, hidden_size)
|
||||
@@ -1388,16 +1268,4 @@ class AudioEmbedding(nn.Module):
|
||||
hidden_states.dtype).to(hidden_states.device))
|
||||
idx += cnt
|
||||
|
||||
else:
|
||||
if self.training:
|
||||
# hidden_states[:, 0:img_set_tensor.shape[0]] =
|
||||
# hidden_states[:, 0:img_set_tensor.shape[0]] +
|
||||
# 0 * img_set_tensor.to(hidden_states.dtype)
|
||||
# .to(hidden_states.device)
|
||||
hidden_states[:, 0:1] = hidden_states[:, 0:1] + \
|
||||
0 * audio_set_tensor[:, 0:1].to(hidden_states.dtype)\
|
||||
.to(hidden_states.device)
|
||||
|
||||
if self.drop is not None:
|
||||
hidden_states = self.drop(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
Reference in New Issue
Block a user