[Bugfix][Model] Fix audio-in-video support for Qwen2.5-Omni and Qwen3-Omni (#33605)

Signed-off-by: linyueqian <linyueqian@outlook.com>
Signed-off-by: Roger Wang <hey@rogerw.io>
Co-authored-by: Roger Wang <hey@rogerw.io>
Authored by Yueqian Lin on 2026-02-04 07:15:29 -05:00; committed by GitHub
parent 824058076c
commit f8516a1ab9
2 changed files with 172 additions and 12 deletions

vllm/model_executor/models/qwen2_5_omni_thinker.py

@@ -113,6 +113,95 @@ except (ImportError, ModuleNotFoundError):
logger = init_logger(__name__)


def check_interleaved_audio_video(
    is_video: torch.Tensor,
    is_audio: torch.Tensor,
    num_video: int,
    num_audio: int,
) -> bool:
    """
    Check whether video and audio positions are interleaved in the
    multimodal region.

    Returns:
        True if video and audio tokens are interleaved, False otherwise.
    """
    if num_video == 0 or num_audio == 0:
        return False
    video_pos = is_video.nonzero(as_tuple=True)[0]
    audio_pos = is_audio.nonzero(as_tuple=True)[0]
    # Interleaved iff the two position ranges overlap: the first video token
    # comes before the last audio token and the first audio token comes
    # before the last video token.
    return (
        video_pos[0].item() < audio_pos[-1].item()
        and audio_pos[0].item() < video_pos[-1].item()
    )


def merge_interleaved_embeddings(
    inputs_embeds: torch.Tensor,
    multimodal_embeddings: "MultiModalEmbeddings",
    is_video: torch.Tensor,
    is_audio: torch.Tensor,
    is_multimodal: torch.Tensor,
    num_video: int,
    num_audio: int,
) -> torch.Tensor:
    """
    Merge embeddings for interleaved audio-in-video sequences.

    When use_audio_in_video=True, video and audio tokens are interleaved in
    the token sequence, but embeddings are provided as separate contiguous
    tensors (video first, then audio). This function reorders video and audio
    embeddings to match sequence position order and scatters them efficiently.

    Args:
        inputs_embeds: The input embeddings tensor to merge into.
        multimodal_embeddings: List of embedding tensors (video, audio, other).
        is_video: Boolean mask for video token positions.
        is_audio: Boolean mask for audio token positions.
        is_multimodal: Boolean mask for all multimodal token positions.
        num_video: Total count of video tokens.
        num_audio: Total count of audio tokens.

    Returns:
        The merged inputs_embeds tensor with multimodal embeddings scattered
        to their correct positions.
    """
    # Categorize embeddings by modality based on token counts.
    # Embeddings come grouped by modality, but the order varies (e.g. image,
    # video, audio or video, audio, depending on input kwargs order).
    video_embeds: list[torch.Tensor] = []
    audio_embeds: list[torch.Tensor] = []
    other_embeds: list[torch.Tensor] = []
    video_remaining = num_video
    audio_remaining = num_audio
    for emb in multimodal_embeddings:
        n = emb.shape[0]
        if video_remaining > 0 and n <= video_remaining:
            video_embeds.append(emb)
            video_remaining -= n
        elif audio_remaining > 0 and n <= audio_remaining:
            audio_embeds.append(emb)
            audio_remaining -= n
        else:
            other_embeds.append(emb)

    # Scatter each modality to its positions.
    if video_embeds:
        video_positions = is_video.nonzero(as_tuple=True)[0]
        inputs_embeds[video_positions] = torch.cat(video_embeds, dim=0)
    if audio_embeds:
        audio_positions = is_audio.nonzero(as_tuple=True)[0]
        inputs_embeds[audio_positions] = torch.cat(audio_embeds, dim=0)
    if other_embeds:
        other_mask = is_multimodal & ~is_video & ~is_audio
        other_positions = other_mask.nonzero(as_tuple=True)[0]
        inputs_embeds[other_positions] = torch.cat(other_embeds, dim=0)
    return inputs_embeds
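Illustration (not part of the diff): a toy call showing the reordering, assuming merge_interleaved_embeddings above is in scope; the sequence, masks, and embedding values are made up.

import torch

hidden = 4
# Toy sequence of 4 tokens: video at positions 0 and 2, audio at 1 and 3.
is_video = torch.tensor([True, False, True, False])
is_audio = torch.tensor([False, True, False, True])
is_multimodal = is_video | is_audio
inputs_embeds = torch.zeros(4, hidden)

# Embeddings arrive grouped by modality: all video rows first, then all audio rows.
video_rows = torch.ones(2, hidden)
audio_rows = torch.full((2, hidden), 2.0)

out = merge_interleaved_embeddings(
    inputs_embeds, [video_rows, audio_rows],
    is_video, is_audio, is_multimodal,
    num_video=2, num_audio=2,
)
print(out[:, 0])  # tensor([1., 2., 1., 2.]) -- rows landed at their interleaved positions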


class Qwen2_5OmniAudioFeatureInputs(TensorSchema):
    """
    Dimensions:
@@ -1286,17 +1375,48 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
        is_multimodal: torch.Tensor | None = None,
        handle_oov_mm_token: bool = False,
    ) -> torch.Tensor:
        # This is to satisfy the type checker for each overload
+        from .utils import _merge_multimodal_embeddings

        if multimodal_embeddings is None or is_multimodal is None:
            return super().embed_input_ids(input_ids)
-        return super().embed_input_ids(
+        inputs_embeds = self._embed_text_input_ids(
            input_ids,
-            multimodal_embeddings=multimodal_embeddings,
+            self.get_language_model().embed_input_ids,
            is_multimodal=is_multimodal,
            handle_oov_mm_token=handle_oov_mm_token,
        )

+        if len(multimodal_embeddings) == 0:
+            return inputs_embeds

+        # Check for audio-in-video: interleaved video and audio tokens
+        # in the multimodal region.
+        video_token_id = self.config.video_token_index
+        audio_token_id = self.config.audio_token_index
+        is_video = is_multimodal & (input_ids == video_token_id)
+        is_audio = is_multimodal & (input_ids == audio_token_id)
+        num_video = is_video.sum().item()
+        num_audio = is_audio.sum().item()
+        if check_interleaved_audio_video(is_video, is_audio, num_video, num_audio):
+            return merge_interleaved_embeddings(
+                inputs_embeds,
+                multimodal_embeddings,
+                is_video,
+                is_audio,
+                is_multimodal,
+                num_video,
+                num_audio,
+            )

+        # Default: standard merge (no interleaving)
+        return _merge_multimodal_embeddings(
+            inputs_embeds, multimodal_embeddings, is_multimodal
+        )
    def forward(
        self,
        input_ids: torch.Tensor | None,
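For contrast (not part of the diff), here is the failure mode the interleaved branch above avoids. The default merge scatters the concatenated multimodal rows into the multimodal positions in order; the sketch below uses a plain in-order scatter for illustration only, not vLLM's actual _merge_multimodal_embeddings, and shows grouped video-then-audio rows landing on alternating video/audio slots out of order.

import torch

hidden = 4
is_multimodal = torch.tensor([True, True, True, True])
# Rows grouped by modality (video first, then audio), while the sequence
# alternates video, audio, video, audio when use_audio_in_video=True.
grouped_rows = torch.cat([torch.ones(2, hidden), torch.full((2, hidden), 2.0)])

inputs_embeds = torch.zeros(4, hidden)
inputs_embeds[is_multimodal] = grouped_rows  # naive in-order scatter
print(inputs_embeds[:, 0])  # tensor([1., 1., 2., 2.]) -- a video row on an audio slot and vice versa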

vllm/model_executor/models/qwen3_omni_moe_thinker.py

@@ -92,6 +92,8 @@ from .qwen2_5_omni_thinker import (
    Qwen2_5OmniConditionalGenerationMixin,
    Qwen2_5OmniThinkerDummyInputsBuilder,
    Qwen2_5OmniThinkerMultiModalProcessor,
    check_interleaved_audio_video,
    merge_interleaved_embeddings,
)
from .qwen2_5_vl import (
    Qwen2_5_VisionAttention,
@@ -1780,6 +1782,19 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
        if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
            return inputs_embeds

        # Detect interleaved audio-in-video early, since it affects
        # both the deepstack path and the final embedding merge.
        video_token_id = self.config.video_token_id
        audio_token_id = self.config.audio_token_id
        is_video = is_multimodal & (input_ids == video_token_id)
        is_audio = is_multimodal & (input_ids == audio_token_id)
        num_video = is_video.sum().item()
        num_audio = is_audio.sum().item()
        is_interleaved = check_interleaved_audio_video(
            is_video, is_audio, num_video, num_audio
        )

        deepstack_input_embeds = None
        # split the feat dim to obtain multi-scale visual feature
        has_vision_embeddings = [
@@ -1791,14 +1806,18 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
        ):
            multiscale_len = len(self.visual.deepstack_visual_indexes)
            multimodal_embeddings_multiscale = []
-            is_vision = torch.zeros_like(is_multimodal)
-            mm_positions = torch.nonzero(is_multimodal, as_tuple=True)[0]
-            mm_position_idx = 0
+            if is_interleaved:
+                # Use input_ids-based mask for correct vision positions
+                # when audio and video tokens are interleaved.
+                is_vision = is_video.clone()
+            else:
+                is_vision = torch.zeros_like(is_multimodal)
+                mm_positions = torch.nonzero(is_multimodal, as_tuple=True)[0]
+                mm_position_idx = 0
            for index, embeddings in enumerate(multimodal_embeddings):
                num_tokens = embeddings.shape[0]
-                current_positions = mm_positions[
-                    mm_position_idx : mm_position_idx + num_tokens
-                ]
                # Vision embeddings
                if embeddings.shape[-1] != self.config.text_config.hidden_size:
@@ -1809,13 +1828,22 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
                    )
                    multimodal_embeddings[index] = embeddings_main
                    multimodal_embeddings_multiscale.append(embeddings_multiscale)
-                    is_vision[current_positions] = True
+                    if not is_interleaved:
+                        current_positions = mm_positions[
+                            mm_position_idx : mm_position_idx + num_tokens
+                        ]
+                        is_vision[current_positions] = True
                # Audio embeddings
                else:
-                    is_vision[current_positions] = False
+                    if not is_interleaved:
+                        current_positions = mm_positions[
+                            mm_position_idx : mm_position_idx + num_tokens
+                        ]
+                        is_vision[current_positions] = False
-                mm_position_idx += num_tokens
+                if not is_interleaved:
+                    mm_position_idx += num_tokens

            deepstack_input_embeds = inputs_embeds.new_zeros(
                inputs_embeds.size(0), multiscale_len * inputs_embeds.size(1)
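Aside (not part of the diff): the "split the feat dim" comment above refers to fused vision embeddings whose last dimension carries a main block plus one extra block per deepstack layer, as implied by the multiscale_len * hidden buffer allocated here. A toy split with made-up sizes:

import torch

hidden_size = 4
multiscale_len = 3  # e.g. len(self.visual.deepstack_visual_indexes)

# Five vision tokens, each with a main block plus multiscale blocks concatenated.
fused = torch.randn(5, hidden_size * (1 + multiscale_len))
main, multiscale = fused.split([hidden_size, hidden_size * multiscale_len], dim=-1)
print(main.shape, multiscale.shape)  # torch.Size([5, 4]) torch.Size([5, 12])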
@@ -1834,6 +1862,18 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
            )
            self._set_deepstack_input_embeds(deepstack_input_embeds)

        if is_interleaved:
            return merge_interleaved_embeddings(
                inputs_embeds,
                multimodal_embeddings,
                is_video,
                is_audio,
                is_multimodal,
                num_video,
                num_audio,
            )

        # Default: standard merge (no interleaving)
        inputs_embeds = _merge_multimodal_embeddings(
            inputs_embeds=inputs_embeds,
            multimodal_embeddings=multimodal_embeddings,