[Model] Refactor Phi-4-multimodal to use merged processor and support V1 (#15477)
Signed-off-by: Isotr0py <2037008807@qq.com>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
@@ -1159,8 +1159,11 @@ class AudioEmbedding(nn.Module):
         input_embeds: torch.FloatTensor,
+        audio_attention_mask: torch.Tensor = None,
         audio_projection_mode: str = "speech",
-    ):
+    ) -> torch.FloatTensor:
         """
         arguments:
             input_embeds: audio features (B, T, D)  B: num audios in a sequence
         """
         if self.freeze_audio_processor:
             with torch.no_grad():
                 audio_features, masks = self.encoder(input_embeds,
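The hunk above threads the new audio_attention_mask through get_audio_features (the method name is confirmed by the call site in the next hunk) and pins down its return type. A minimal usage sketch, assuming an already-constructed AudioEmbedding instance; the helper name, the (T, D) feature layout, and the all-ones mask convention are assumptions for illustration, not part of this diff:

import torch
from torch import nn

def encode_one_clip(embed: nn.Module, feats: torch.Tensor) -> torch.Tensor:
    # feats: (T, D) frame features for a single clip (shape assumed).
    # Batch to (1, T, D), since the encoder consumes batched inputs.
    batched = feats.unsqueeze(0)
    # Mark every frame as valid; the real mask convention lives inside
    # self.encoder and is not visible in this hunk.
    mask = torch.ones(1, feats.shape[0], dtype=torch.bool)
    out = embed.get_audio_features(
        batched,
        audio_attention_mask=mask,
        audio_projection_mode="speech",
    )
    return out.squeeze(0)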
@@ -1210,62 +1213,20 @@ class AudioEmbedding(nn.Module):
 
     def forward(
         self,
-        input_ids: torch.LongTensor,
-        input_embeds: torch.FloatTensor,
-        audio_embed_sizes,
-        **kwargs,
+        audio_features: torch.FloatTensor,
+        audio_attention_mask: torch.Tensor = None,
+        audio_projection_mode: str = "speech",
     ) -> torch.FloatTensor:
         """
         arguments:
-            input_ids: input text ids (B, U)
-            input_embeds: audio features (B, T, D)  B: num audios in a sequence
+            audio_features: audio features (T, D)
 
         returns:
             audio_embeds: audio embeddings (num_audio_tokens, hidden_dim)
         """
-        assert input_embeds is not None and len(input_embeds) == len(
-            audio_embed_sizes)
-
-        input_shape = input_ids.size()
-        input_ids = input_ids.view(-1, input_shape[-1])
-
-        with torch.no_grad():
-            positions = (input_ids == _AUDIO_PLACEHOLDER_TOKEN_ID).nonzero(
-                as_tuple=False)
-
-        if not isinstance(input_embeds, list):
-            input_embeds = [input_embeds]
-
-        audio_projection_mode = kwargs.get("audio_projection_mode", "speech")
-        audio_set_tensor = [
-            self.get_audio_features(
-                input_embed, audio_projection_mode=audio_projection_mode)
-            for input_embed in input_embeds
-        ]
-
-        with torch.no_grad():
-            input_ids.clamp_min_(0).clamp_max_(self.vocab_size)
-
-        if "wte" in kwargs:
-            # we use the token embedding layer from the huggingface model, this
-            # is REQUIRED to make sure we are using the loaded weights.
-            hidden_states = kwargs["wte"](input_ids)
-        else:
-            # otherwise, we use token embedding in pretrained mixformer from
-            # phi team
-            hidden_states = self.wte(input_ids)
-
-        if len(positions.tolist()) > 0:
-            assert sum(audio_embed_sizes) == len(
-                positions
-            ), "please ensure the encoder outputs have the same length as"\
-                " defined in input_ids!"
-            idx = 0
-            for i in range(len(audio_embed_sizes)):
-                cnt = audio_embed_sizes[i]
-                assert audio_set_tensor[i].shape[0] == 1
-                hidden_states[
-                    positions[idx, 0],
-                    positions[idx, 1]:positions[idx, 1] + cnt,
-                ] = (audio_set_tensor[i][0, :audio_embed_sizes[i], :].to(
-                    hidden_states.dtype).to(hidden_states.device))
-                idx += cnt
-
-        return hidden_states
+        audio_embeds = self.get_audio_features(
+            audio_features.unsqueeze(0),
+            audio_attention_mask=audio_attention_mask,
+            audio_projection_mode=audio_projection_mode,
+        )
+        return audio_embeds.squeeze(0)
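Taken together, the two hunks shrink forward from a routine that embedded the text tokens itself and scattered audio embeddings into them at _AUDIO_PLACEHOLDER_TOKEN_ID positions, down to a thin wrapper over get_audio_features; locating placeholders and merging embeddings now happens outside this module, in the merged-processor/V1 path named in the commit title. A before/after call sketch, assuming a built embed module (names and shapes are illustrative, not from this diff):

import torch
from torch import nn

def audio_token_embeds(embed: nn.Module, feats: torch.Tensor) -> torch.Tensor:
    # New contract (from the + lines above): one clip of (T, D) features
    # in, a (num_audio_tokens, hidden_dim) embedding block out; the mask
    # and projection-mode arguments fall back to their defaults.
    return embed(feats)

# Old contract, for contrast (reconstructed from the removed lines; the
# wte= kwarg passed in the language model's token embedding layer):
#   hidden_states = embed(input_ids, input_embeds, audio_embed_sizes,
#                         wte=language_model.embed_tokens)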