[Model] Refactor Phi-4-multimodal to use merged processor and support V1 (#15477)

Signed-off-by: Isotr0py <2037008807@qq.com>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
Isotr0py
2025-04-19 17:26:11 +08:00
committed by GitHub
parent d9737ca1c6
commit 83f3c3bd91
15 changed files with 818 additions and 1246 deletions


@@ -1159,8 +1159,11 @@ class AudioEmbedding(nn.Module):
         input_embeds: torch.FloatTensor,
         audio_attention_mask: torch.Tensor = None,
         audio_projection_mode: str = "speech",
-    ):
+    ) -> torch.FloatTensor:
+        """
+        arguments:
+            input_embeds: audio features (B, T, D)  B: num audios in a sequence
+        """
         if self.freeze_audio_processor:
             with torch.no_grad():
                 audio_features, masks = self.encoder(input_embeds,
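For context, the docstring added above pins down the input contract: get_audio_features consumes a (B, T, D) batch of audio features. A minimal sketch of that contract follows, with hypothetical dimensions — the real T and D come from the feature extractor config, not from this commit:

    import torch

    features = torch.randn(750, 80)  # one clip: (T, D); dims are illustrative
    batched = features.unsqueeze(0)  # (1, T, D): the batch shape expected here
    assert batched.shape == (1, 750, 80)
    assert torch.equal(batched.squeeze(0), features)

The refactored forward() in the next hunk relies on exactly this unsqueeze/squeeze round trip to push a single clip through the batched encoder path.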
@@ -1210,62 +1213,20 @@ class AudioEmbedding(nn.Module):
     def forward(
         self,
-        input_ids: torch.LongTensor,
-        input_embeds: torch.FloatTensor,
-        audio_embed_sizes,
-        **kwargs,
+        audio_features: torch.FloatTensor,
+        audio_attention_mask: torch.Tensor = None,
+        audio_projection_mode: str = "speech",
     ) -> torch.FloatTensor:
         """
         arguments:
-            input_ids: input text ids (B, U)
-            input_embeds: audio features (B, T, D)  B: num audios in a sequence
+            audio_features: audio features (T, D)
         returns:
             audio_embeds: audio embeddings (num_audio_tokens, hidden_dim)
         """
-        assert input_embeds is not None and len(input_embeds) == len(
-            audio_embed_sizes)
-        input_shape = input_ids.size()
-        input_ids = input_ids.view(-1, input_shape[-1])
-        with torch.no_grad():
-            positions = (input_ids == _AUDIO_PLACEHOLDER_TOKEN_ID).nonzero(
-                as_tuple=False)
-        if not isinstance(input_embeds, list):
-            input_embeds = [input_embeds]
-        audio_projection_mode = kwargs.get("audio_projection_mode", "speech")
-        audio_set_tensor = [
-            self.get_audio_features(
-                input_embed, audio_projection_mode=audio_projection_mode)
-            for input_embed in input_embeds
-        ]
-        with torch.no_grad():
-            input_ids.clamp_min_(0).clamp_max_(self.vocab_size)
-        if "wte" in kwargs:
-            # we use the token embedding layer from the huggingface model, this
-            # is REQUIRED to make sure we are using the loaded weights.
-            hidden_states = kwargs["wte"](input_ids)
-        else:
-            # otherwise, we use token embedding in pretrained mixformer from
-            # phi team
-            hidden_states = self.wte(input_ids)
-        if len(positions.tolist()) > 0:
-            assert sum(audio_embed_sizes) == len(
-                positions
-            ), "please ensure the encoder outputs have the same length as"\
-                " defined in input_ids!"
-            idx = 0
-            for i in range(len(audio_embed_sizes)):
-                cnt = audio_embed_sizes[i]
-                assert audio_set_tensor[i].shape[0] == 1
-                hidden_states[
-                    positions[idx, 0],
-                    positions[idx, 1]:positions[idx, 1] + cnt,
-                ] = (audio_set_tensor[i][0, :audio_embed_sizes[i], :].to(
-                    hidden_states.dtype).to(hidden_states.device))
-                idx += cnt
-        return hidden_states
+        audio_embeds = self.get_audio_features(
+            audio_features.unsqueeze(0),
+            audio_attention_mask=audio_attention_mask,
+            audio_projection_mode=audio_projection_mode,
+        )
+        return audio_embeds.squeeze(0)
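After this change forward() no longer scatters audio embeddings into the text hidden states itself: it returns per-audio embeddings and leaves placeholder merging to vLLM's shared multimodal path (the merge_multimodal_embeddings utility used by merged-processor models). A minimal sketch of that scatter step, assuming a boolean placeholder mask rather than vLLM's actual API:

    import torch

    def merge_audio_embeddings(
        inputs_embeds: torch.Tensor,   # (seq_len, hidden_dim) text embeddings
        audio_embeds: torch.Tensor,    # (num_audio_tokens, hidden_dim)
        is_audio_token: torch.Tensor,  # (seq_len,) bool placeholder mask
    ) -> torch.Tensor:
        # Expect exactly one audio embedding per placeholder token, in order.
        assert int(is_audio_token.sum()) == audio_embeds.shape[0]
        merged = inputs_embeds.clone()
        # Boolean indexing selects the placeholder rows; assignment writes
        # the audio embeddings into those positions.
        merged[is_audio_token] = audio_embeds.to(merged.dtype)
        return merged

This is the same effect the deleted loop over audio_embed_sizes achieved by hand, minus the input_ids bookkeeping; doing it once in the framework is what lets the model drop _AUDIO_PLACEHOLDER_TOKEN_ID, the wte lookup, and the **kwargs plumbing.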