[Bugfix][Model] Fix audio-in-video support for Qwen2.5-Omni and Qwen3-Omni (#33605)
Signed-off-by: linyueqian <linyueqian@outlook.com> Signed-off-by: Roger Wang <hey@rogerw.io> Co-authored-by: Roger Wang <hey@rogerw.io>
This commit is contained in:
@@ -113,6 +113,95 @@ except (ImportError, ModuleNotFoundError):
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def check_interleaved_audio_video(
    is_video: torch.Tensor,
    is_audio: torch.Tensor,
    num_video: int,
    num_audio: int,
) -> bool:
    """
    Check if video and audio positions are interleaved in the multimodal region.

    Two modalities are considered interleaved when their position ranges
    overlap, i.e. each modality starts before the other one ends.

    Returns:
        True if video and audio tokens are interleaved, False otherwise.
    """
    # Without both modalities present there is nothing to interleave.
    if not (num_video and num_audio):
        return False

    video_idx = torch.nonzero(is_video, as_tuple=True)[0]
    audio_idx = torch.nonzero(is_audio, as_tuple=True)[0]

    # nonzero() yields sorted indices, so first/last are the range bounds.
    first_video, last_video = int(video_idx[0]), int(video_idx[-1])
    first_audio, last_audio = int(audio_idx[0]), int(audio_idx[-1])

    return first_video < last_audio and first_audio < last_video
|
||||
|
||||
|
||||
def merge_interleaved_embeddings(
    inputs_embeds: torch.Tensor,
    multimodal_embeddings: "MultiModalEmbeddings",
    is_video: torch.Tensor,
    is_audio: torch.Tensor,
    is_multimodal: torch.Tensor,
    num_video: int,
    num_audio: int,
) -> torch.Tensor:
    """
    Merge embeddings for interleaved audio-in-video sequences.

    When use_audio_in_video=True, video and audio tokens are interleaved in
    the token sequence, but embeddings are provided as separate contiguous
    tensors (video first, then audio). This function reorders video and audio
    embeddings to match sequence position order and scatters them efficiently.

    Args:
        inputs_embeds: The input embeddings tensor to merge into.
        multimodal_embeddings: List of embedding tensors (video, audio, other).
        is_video: Boolean mask for video token positions.
        is_audio: Boolean mask for audio token positions.
        is_multimodal: Boolean mask for all multimodal token positions.
        num_video: Total count of video tokens.
        num_audio: Total count of audio tokens.

    Returns:
        The merged inputs_embeds tensor with multimodal embeddings scattered
        to their correct positions.
    """
    # Bucket each embedding tensor by modality using its token count.
    # Embeddings arrive grouped per modality, but the group order depends
    # on the kwargs order (e.g. image/video/audio vs video/audio), so a
    # greedy size-based assignment is used: fill the video budget first,
    # then the audio budget, and treat everything else as "other".
    video_bucket: list[torch.Tensor] = []
    audio_bucket: list[torch.Tensor] = []
    other_bucket: list[torch.Tensor] = []
    video_budget, audio_budget = num_video, num_audio

    for emb in multimodal_embeddings:
        size = emb.shape[0]
        if video_budget > 0 and size <= video_budget:
            video_bucket.append(emb)
            video_budget -= size
        elif audio_budget > 0 and size <= audio_budget:
            audio_bucket.append(emb)
            audio_budget -= size
        else:
            other_bucket.append(emb)

    # Scatter each modality's embeddings onto its masked positions.
    def _scatter(mask: torch.Tensor, bucket: list[torch.Tensor]) -> None:
        positions = torch.nonzero(mask, as_tuple=True)[0]
        inputs_embeds[positions] = torch.cat(bucket, dim=0)

    if video_bucket:
        _scatter(is_video, video_bucket)
    if audio_bucket:
        _scatter(is_audio, audio_bucket)
    if other_bucket:
        # Everything multimodal that is neither video nor audio.
        _scatter(is_multimodal & ~is_video & ~is_audio, other_bucket)

    return inputs_embeds
|
||||
|
||||
|
||||
class Qwen2_5OmniAudioFeatureInputs(TensorSchema):
|
||||
"""
|
||||
Dimensions:
|
||||
@@ -1286,17 +1375,48 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
|
||||
is_multimodal: torch.Tensor | None = None,
|
||||
handle_oov_mm_token: bool = False,
|
||||
) -> torch.Tensor:
|
||||
# This is to satisfy the type checker for each overload
|
||||
from .utils import _merge_multimodal_embeddings
|
||||
|
||||
if multimodal_embeddings is None or is_multimodal is None:
|
||||
return super().embed_input_ids(input_ids)
|
||||
|
||||
return super().embed_input_ids(
|
||||
inputs_embeds = self._embed_text_input_ids(
|
||||
input_ids,
|
||||
multimodal_embeddings=multimodal_embeddings,
|
||||
self.get_language_model().embed_input_ids,
|
||||
is_multimodal=is_multimodal,
|
||||
handle_oov_mm_token=handle_oov_mm_token,
|
||||
)
|
||||
|
||||
if len(multimodal_embeddings) == 0:
|
||||
return inputs_embeds
|
||||
|
||||
# Check for audio-in-video: interleaved video and audio tokens
|
||||
# in the multimodal region.
|
||||
video_token_id = self.config.video_token_index
|
||||
audio_token_id = self.config.audio_token_index
|
||||
|
||||
is_video = is_multimodal & (input_ids == video_token_id)
|
||||
is_audio = is_multimodal & (input_ids == audio_token_id)
|
||||
|
||||
num_video = is_video.sum().item()
|
||||
num_audio = is_audio.sum().item()
|
||||
|
||||
if check_interleaved_audio_video(is_video, is_audio, num_video, num_audio):
|
||||
return merge_interleaved_embeddings(
|
||||
inputs_embeds,
|
||||
multimodal_embeddings,
|
||||
is_video,
|
||||
is_audio,
|
||||
is_multimodal,
|
||||
num_video,
|
||||
num_audio,
|
||||
)
|
||||
|
||||
# Default: standard merge (no interleaving)
|
||||
return _merge_multimodal_embeddings(
|
||||
inputs_embeds, multimodal_embeddings, is_multimodal
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor | None,
|
||||
|
||||
@@ -92,6 +92,8 @@ from .qwen2_5_omni_thinker import (
|
||||
Qwen2_5OmniConditionalGenerationMixin,
|
||||
Qwen2_5OmniThinkerDummyInputsBuilder,
|
||||
Qwen2_5OmniThinkerMultiModalProcessor,
|
||||
check_interleaved_audio_video,
|
||||
merge_interleaved_embeddings,
|
||||
)
|
||||
from .qwen2_5_vl import (
|
||||
Qwen2_5_VisionAttention,
|
||||
@@ -1780,6 +1782,19 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
|
||||
if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
|
||||
return inputs_embeds
|
||||
|
||||
# Detect interleaved audio-in-video early, since it affects
|
||||
# both the deepstack path and the final embedding merge.
|
||||
video_token_id = self.config.video_token_id
|
||||
audio_token_id = self.config.audio_token_id
|
||||
is_video = is_multimodal & (input_ids == video_token_id)
|
||||
is_audio = is_multimodal & (input_ids == audio_token_id)
|
||||
num_video = is_video.sum().item()
|
||||
num_audio = is_audio.sum().item()
|
||||
|
||||
is_interleaved = check_interleaved_audio_video(
|
||||
is_video, is_audio, num_video, num_audio
|
||||
)
|
||||
|
||||
deepstack_input_embeds = None
|
||||
# split the feat dim to obtain multi-scale visual feature
|
||||
has_vision_embeddings = [
|
||||
@@ -1791,14 +1806,18 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
|
||||
):
|
||||
multiscale_len = len(self.visual.deepstack_visual_indexes)
|
||||
multimodal_embeddings_multiscale = []
|
||||
is_vision = torch.zeros_like(is_multimodal)
|
||||
mm_positions = torch.nonzero(is_multimodal, as_tuple=True)[0]
|
||||
mm_position_idx = 0
|
||||
|
||||
if is_interleaved:
|
||||
# Use input_ids-based mask for correct vision positions
|
||||
# when audio and video tokens are interleaved.
|
||||
is_vision = is_video.clone()
|
||||
else:
|
||||
is_vision = torch.zeros_like(is_multimodal)
|
||||
mm_positions = torch.nonzero(is_multimodal, as_tuple=True)[0]
|
||||
mm_position_idx = 0
|
||||
|
||||
for index, embeddings in enumerate(multimodal_embeddings):
|
||||
num_tokens = embeddings.shape[0]
|
||||
current_positions = mm_positions[
|
||||
mm_position_idx : mm_position_idx + num_tokens
|
||||
]
|
||||
|
||||
# Vision embeddings
|
||||
if embeddings.shape[-1] != self.config.text_config.hidden_size:
|
||||
@@ -1809,13 +1828,22 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
|
||||
)
|
||||
multimodal_embeddings[index] = embeddings_main
|
||||
multimodal_embeddings_multiscale.append(embeddings_multiscale)
|
||||
is_vision[current_positions] = True
|
||||
if not is_interleaved:
|
||||
current_positions = mm_positions[
|
||||
mm_position_idx : mm_position_idx + num_tokens
|
||||
]
|
||||
is_vision[current_positions] = True
|
||||
|
||||
# Audio embeddings
|
||||
else:
|
||||
is_vision[current_positions] = False
|
||||
if not is_interleaved:
|
||||
current_positions = mm_positions[
|
||||
mm_position_idx : mm_position_idx + num_tokens
|
||||
]
|
||||
is_vision[current_positions] = False
|
||||
|
||||
mm_position_idx += num_tokens
|
||||
if not is_interleaved:
|
||||
mm_position_idx += num_tokens
|
||||
|
||||
deepstack_input_embeds = inputs_embeds.new_zeros(
|
||||
inputs_embeds.size(0), multiscale_len * inputs_embeds.size(1)
|
||||
@@ -1834,6 +1862,18 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
|
||||
)
|
||||
self._set_deepstack_input_embeds(deepstack_input_embeds)
|
||||
|
||||
if is_interleaved:
|
||||
return merge_interleaved_embeddings(
|
||||
inputs_embeds,
|
||||
multimodal_embeddings,
|
||||
is_video,
|
||||
is_audio,
|
||||
is_multimodal,
|
||||
num_video,
|
||||
num_audio,
|
||||
)
|
||||
|
||||
# Default: standard merge (no interleaving)
|
||||
inputs_embeds = _merge_multimodal_embeddings(
|
||||
inputs_embeds=inputs_embeds,
|
||||
multimodal_embeddings=multimodal_embeddings,
|
||||
|
||||
Reference in New Issue
Block a user