[Bugfix][Model] Fix audio-in-video support for Qwen2.5-Omni and Qwen3-Omni (#33605)

Signed-off-by: linyueqian <linyueqian@outlook.com>
Signed-off-by: Roger Wang <hey@rogerw.io>
Co-authored-by: Roger Wang <hey@rogerw.io>
Authored by Yueqian Lin on 2026-02-04 07:15:29 -05:00; committed by GitHub
parent 824058076c
commit f8516a1ab9
2 changed files with 172 additions and 12 deletions
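
With Qwen Omni's audio-in-video mode (the processor's use_audio_in_video option), a clip's audio track is emitted inside the video placeholder region, so audio and video tokens alternate instead of forming two contiguous blocks. The embedding merge below previously assumed contiguous per-modality runs, which scrambled positions in that mode. A minimal sketch of the kind of test the new check_interleaved_audio_video helper performs (the name and signature come from the diff; the body here is an assumption for illustration, not vLLM's implementation):

import torch

def check_interleaved_sketch(is_video, is_audio, num_video, num_audio):
    # Assumed logic: interleaving requires both modalities to be present
    # and their position ranges to overlap, i.e. neither modality's
    # tokens sit entirely before the other's.
    if num_video == 0 or num_audio == 0:
        return False
    v = torch.nonzero(is_video, as_tuple=True)[0]
    a = torch.nonzero(is_audio, as_tuple=True)[0]
    return bool(v.min() < a.max()) and bool(a.min() < v.max())

# Toy layout V V A A V V: the video chunks are split around audio.
ids = torch.tensor([7, 7, 8, 8, 7, 7])   # 7 = video token, 8 = audio token
is_v, is_a = ids == 7, ids == 8
assert check_interleaved_sketch(is_v, is_a, 4, 2)          # interleaved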


@@ -92,6 +92,8 @@ from .qwen2_5_omni_thinker import (
     Qwen2_5OmniConditionalGenerationMixin,
     Qwen2_5OmniThinkerDummyInputsBuilder,
     Qwen2_5OmniThinkerMultiModalProcessor,
+    check_interleaved_audio_video,
+    merge_interleaved_embeddings,
 )
 from .qwen2_5_vl import (
     Qwen2_5_VisionAttention,
@@ -1780,6 +1782,19 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
         if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
             return inputs_embeds
 
+        # Detect interleaved audio-in-video early, since it affects
+        # both the deepstack path and the final embedding merge.
+        video_token_id = self.config.video_token_id
+        audio_token_id = self.config.audio_token_id
+        is_video = is_multimodal & (input_ids == video_token_id)
+        is_audio = is_multimodal & (input_ids == audio_token_id)
+        num_video = is_video.sum().item()
+        num_audio = is_audio.sum().item()
+        is_interleaved = check_interleaved_audio_video(
+            is_video, is_audio, num_video, num_audio
+        )
+
         deepstack_input_embeds = None
         # split the feat dim to obtain multi-scale visual feature
         has_vision_embeddings = [
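
The detection matters because of how the loop in the next hunk tracks positions: mm_positions lists every multimodal slot in order, and each embedding claims the next num_tokens of them. That bookkeeping is only valid when each embedding's tokens are contiguous. A toy illustration of the failure (token layout assumed):

import torch

# Sequence [text, V, V, A, A, V, V]: one video's tokens are split
# around an audio chunk.
is_multimodal = torch.tensor([False, True, True, True, True, True, True])
mm_positions = torch.nonzero(is_multimodal, as_tuple=True)[0]  # [1..6]

# A 4-token video embedding claiming the next 4 slots marks positions
# 3 and 4 as vision -- but those are audio tokens.
wrong = mm_positions[0:4]                          # tensor([1, 2, 3, 4])

# The input_ids-based mask used on the interleaved path is exact:
is_video = torch.tensor([False, True, True, False, False, True, True])
right = torch.nonzero(is_video, as_tuple=True)[0]  # tensor([1, 2, 5, 6])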
@@ -1791,14 +1806,18 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
         ):
             multiscale_len = len(self.visual.deepstack_visual_indexes)
             multimodal_embeddings_multiscale = []
-            is_vision = torch.zeros_like(is_multimodal)
-            mm_positions = torch.nonzero(is_multimodal, as_tuple=True)[0]
-            mm_position_idx = 0
+            if is_interleaved:
+                # Use input_ids-based mask for correct vision positions
+                # when audio and video tokens are interleaved.
+                is_vision = is_video.clone()
+            else:
+                is_vision = torch.zeros_like(is_multimodal)
+            mm_positions = torch.nonzero(is_multimodal, as_tuple=True)[0]
+            mm_position_idx = 0
             for index, embeddings in enumerate(multimodal_embeddings):
                 num_tokens = embeddings.shape[0]
-                current_positions = mm_positions[
-                    mm_position_idx : mm_position_idx + num_tokens
-                ]
                 # Vision embeddings
                 if embeddings.shape[-1] != self.config.text_config.hidden_size:
@@ -1809,13 +1828,22 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
                     )
                     multimodal_embeddings[index] = embeddings_main
                     multimodal_embeddings_multiscale.append(embeddings_multiscale)
-                    is_vision[current_positions] = True
+                    if not is_interleaved:
+                        current_positions = mm_positions[
+                            mm_position_idx : mm_position_idx + num_tokens
+                        ]
+                        is_vision[current_positions] = True
                 # Audio embeddings
                 else:
-                    is_vision[current_positions] = False
+                    if not is_interleaved:
+                        current_positions = mm_positions[
+                            mm_position_idx : mm_position_idx + num_tokens
+                        ]
+                        is_vision[current_positions] = False
-                mm_position_idx += num_tokens
+                if not is_interleaved:
+                    mm_position_idx += num_tokens
 
             deepstack_input_embeds = inputs_embeds.new_zeros(
                 inputs_embeds.size(0), multiscale_len * inputs_embeds.size(1)
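
For reference, the deepstack buffer allocated on the last line above is a flat (seq_len, multiscale_len * hidden) tensor: each multi-scale level occupies one hidden-sized slice of the feature dimension. A quick shape check with toy sizes:

import torch

seq_len, hidden, multiscale_len = 6, 4, 3
inputs_embeds = torch.zeros(seq_len, hidden)
deepstack_input_embeds = inputs_embeds.new_zeros(
    inputs_embeds.size(0), multiscale_len * inputs_embeds.size(1)
)
print(deepstack_input_embeds.shape)  # torch.Size([6, 12]): 3 levels x 4 dims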
@@ -1834,6 +1862,18 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
             )
             self._set_deepstack_input_embeds(deepstack_input_embeds)
 
+        if is_interleaved:
+            return merge_interleaved_embeddings(
+                inputs_embeds,
+                multimodal_embeddings,
+                is_video,
+                is_audio,
+                is_multimodal,
+                num_video,
+                num_audio,
+            )
+
+        # Default: standard merge (no interleaving)
         inputs_embeds = _merge_multimodal_embeddings(
             inputs_embeds=inputs_embeds,
             multimodal_embeddings=multimodal_embeddings,
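
On the interleaved path, the final merge is delegated to merge_interleaved_embeddings, imported from qwen2_5_omni_thinker.py. Judging from its arguments, it must route each modality's embeddings to its own input_ids-derived mask rather than filling multimodal slots in arrival order. A minimal sketch of that routing, with the per-modality tensors assumed to be pre-separated (the real helper receives the mixed embedding list plus the num_video/num_audio counts and performs that split itself):

import torch

def merge_interleaved_sketch(inputs_embeds, video_embeds, audio_embeds,
                             is_video, is_audio):
    # inputs_embeds: (seq_len, hidden); video_embeds: (num_video, hidden);
    # audio_embeds: (num_audio, hidden); masks are bool over seq_len.
    # Scatter each modality into its own positions; correctness no longer
    # depends on modality chunks being contiguous in the sequence.
    out = inputs_embeds.clone()
    out[is_video] = video_embeds
    out[is_audio] = audio_embeds
    return out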