[Bugfix][Model] Fix audio-in-video support for Qwen2.5-Omni and Qwen3-Omni (#33605)
Signed-off-by: linyueqian <linyueqian@outlook.com>
Signed-off-by: Roger Wang <hey@rogerw.io>
Co-authored-by: Roger Wang <hey@rogerw.io>
@@ -92,6 +92,8 @@ from .qwen2_5_omni_thinker import (
     Qwen2_5OmniConditionalGenerationMixin,
     Qwen2_5OmniThinkerDummyInputsBuilder,
     Qwen2_5OmniThinkerMultiModalProcessor,
+    check_interleaved_audio_video,
+    merge_interleaved_embeddings,
 )
 from .qwen2_5_vl import (
     Qwen2_5_VisionAttention,
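Note: both helpers come from the shared Qwen2.5-Omni thinker module, so one interleaving implementation serves both model generations. As a rough sketch of what the check has to establish, inferred from the call site below (illustrative only, not the actual vLLM implementation):

    import torch

    def check_interleaved_audio_video_sketch(
        is_video: torch.Tensor,  # bool mask over the token sequence
        is_audio: torch.Tensor,  # bool mask over the token sequence
        num_video: int,
        num_audio: int,
    ) -> bool:
        # Interleaving requires both modalities to be present at all.
        if num_video == 0 or num_audio == 0:
            return False
        video_pos = torch.nonzero(is_video, as_tuple=True)[0]
        audio_pos = torch.nonzero(is_audio, as_tuple=True)[0]
        v_min, v_max = video_pos.min().item(), video_pos.max().item()
        a_min, a_max = audio_pos.min().item(), audio_pos.max().item()
        # Audio-in-video shows up as overlapping position ranges rather
        # than two disjoint contiguous runs of placeholder tokens.
        return v_min < a_max and a_min < v_max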
@@ -1780,6 +1782,19 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
         if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
             return inputs_embeds
 
+        # Detect interleaved audio-in-video early, since it affects
+        # both the deepstack path and the final embedding merge.
+        video_token_id = self.config.video_token_id
+        audio_token_id = self.config.audio_token_id
+        is_video = is_multimodal & (input_ids == video_token_id)
+        is_audio = is_multimodal & (input_ids == audio_token_id)
+        num_video = is_video.sum().item()
+        num_audio = is_audio.sum().item()
+
+        is_interleaved = check_interleaved_audio_video(
+            is_video, is_audio, num_video, num_audio
+        )
+
         deepstack_input_embeds = None
         # split the feat dim to obtain multi-scale visual feature
         has_vision_embeddings = [
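For a concrete picture of what this block computes, consider a toy sequence (token ids are made up for illustration; the real ones come from self.config):

    import torch

    TEXT, VIDEO, AUDIO = 0, 1, 2  # hypothetical placeholder ids
    # use_audio_in_video=True interleaves audio placeholders inside the
    # video span: T V V A V A T
    input_ids = torch.tensor([TEXT, VIDEO, VIDEO, AUDIO, VIDEO, AUDIO, TEXT])
    is_multimodal = (input_ids == VIDEO) | (input_ids == AUDIO)
    is_video = is_multimodal & (input_ids == VIDEO)
    is_audio = is_multimodal & (input_ids == AUDIO)
    print(is_video.sum().item(), is_audio.sum().item())  # 3 2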
@@ -1791,14 +1806,18 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
         ):
             multiscale_len = len(self.visual.deepstack_visual_indexes)
             multimodal_embeddings_multiscale = []
-            is_vision = torch.zeros_like(is_multimodal)
-            mm_positions = torch.nonzero(is_multimodal, as_tuple=True)[0]
-            mm_position_idx = 0
+            if is_interleaved:
+                # Use input_ids-based mask for correct vision positions
+                # when audio and video tokens are interleaved.
+                is_vision = is_video.clone()
+            else:
+                is_vision = torch.zeros_like(is_multimodal)
+                mm_positions = torch.nonzero(is_multimodal, as_tuple=True)[0]
+                mm_position_idx = 0
 
             for index, embeddings in enumerate(multimodal_embeddings):
                 num_tokens = embeddings.shape[0]
-                current_positions = mm_positions[
-                    mm_position_idx : mm_position_idx + num_tokens
-                ]
 
                 # Vision embeddings
                 if embeddings.shape[-1] != self.config.text_config.hidden_size:
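This hunk is the heart of the fix. Without interleaving, every entry of multimodal_embeddings occupies a contiguous run of is_multimodal positions, so slicing mm_positions with a running index recovers each embedding's positions. With audio-in-video, the video tokens are no longer contiguous, and the positional slice mislabels audio positions as vision. Continuing the toy sequence T V V A V A T from above:

    import torch

    is_multimodal = torch.tensor([False, True, True, True, True, True, False])
    mm_positions = torch.nonzero(is_multimodal, as_tuple=True)[0]
    # The video embedding covers 3 tokens sitting at positions 1, 2 and 4,
    # but the old slice-based bookkeeping would take the first 3 multimodal
    # positions instead:
    print(mm_positions[0:3])  # tensor([1, 2, 3]) -- position 3 is audio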
@@ -1809,13 +1828,22 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
                     )
                     multimodal_embeddings[index] = embeddings_main
                     multimodal_embeddings_multiscale.append(embeddings_multiscale)
-                    is_vision[current_positions] = True
+                    if not is_interleaved:
+                        current_positions = mm_positions[
+                            mm_position_idx : mm_position_idx + num_tokens
+                        ]
+                        is_vision[current_positions] = True
 
                 # Audio embeddings
                 else:
-                    is_vision[current_positions] = False
+                    if not is_interleaved:
+                        current_positions = mm_positions[
+                            mm_position_idx : mm_position_idx + num_tokens
+                        ]
+                        is_vision[current_positions] = False
 
-                mm_position_idx += num_tokens
+                if not is_interleaved:
+                    mm_position_idx += num_tokens
 
             deepstack_input_embeds = inputs_embeds.new_zeros(
                 inputs_embeds.size(0), multiscale_len * inputs_embeds.size(1)
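For context on the buffer being filled here: each vision embedding arrives with its main feature and multiscale_len extra deepstack scales fused along the feature dim, gets split, and the extra scales are scattered into the vision rows. A minimal sketch with assumed toy dimensions (not the actual model code):

    import torch

    hidden, multiscale_len, seq_len = 8, 3, 10
    num_vision_tokens = 4

    # Fused encoder output: hidden * (1 + multiscale_len) wide.
    embeddings = torch.randn(num_vision_tokens, hidden * (1 + multiscale_len))
    embeddings_main, embeddings_multiscale = torch.split(
        embeddings, [hidden, hidden * multiscale_len], dim=-1
    )

    # Scatter the multiscale part into the rows flagged as vision;
    # audio rows stay zero, which is why is_vision must be correct.
    is_vision = torch.zeros(seq_len, dtype=torch.bool)
    is_vision[2:6] = True  # toy positions of the 4 vision tokens
    deepstack_input_embeds = torch.zeros(seq_len, multiscale_len * hidden)
    deepstack_input_embeds[is_vision] = embeddings_multiscale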
@@ -1834,6 +1862,18 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
             )
             self._set_deepstack_input_embeds(deepstack_input_embeds)
 
+        if is_interleaved:
+            return merge_interleaved_embeddings(
+                inputs_embeds,
+                multimodal_embeddings,
+                is_video,
+                is_audio,
+                is_multimodal,
+                num_video,
+                num_audio,
+            )
+
+        # Default: standard merge (no interleaving)
         inputs_embeds = _merge_multimodal_embeddings(
             inputs_embeds=inputs_embeds,
             multimodal_embeddings=multimodal_embeddings,
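The interleaved merge itself has to scatter each modality by its own mask, rather than doing one flat is_multimodal scatter whose ordering assumes contiguous runs. A simplified sketch of the idea (the real merge_interleaved_embeddings in qwen2_5_omni_thinker takes the embeddings list plus token counts; this collapses it to one tensor per modality):

    import torch

    def merge_interleaved_sketch(
        inputs_embeds: torch.Tensor,  # (seq, hidden)
        video_embeds: torch.Tensor,   # (num_video, hidden)
        audio_embeds: torch.Tensor,   # (num_audio, hidden)
        is_video: torch.Tensor,       # (seq,) bool
        is_audio: torch.Tensor,       # (seq,) bool
    ) -> torch.Tensor:
        # Scattering per modality makes the relative ordering of video and
        # audio placeholders in the sequence irrelevant.
        out = inputs_embeds.clone()
        out[is_video] = video_embeds
        out[is_audio] = audio_embeds
        return out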