[Bugfix] Fix Qwen2.5-Omni and Qwen3-Omni mixed-modality embed regression (#35368)

Signed-off-by: linyueqian <linyueqian@outlook.com>
This commit is contained in:
Yueqian Lin
2026-02-26 06:58:23 -05:00
committed by GitHub
parent 01914445b0
commit c0615a296d
3 changed files with 379 additions and 21 deletions

View File

@@ -1376,23 +1376,12 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
is_multimodal: torch.Tensor | None = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor:
from .utils import _merge_multimodal_embeddings
if multimodal_embeddings is None or is_multimodal is None:
return super().embed_input_ids(input_ids)
inputs_embeds = self._embed_text_input_ids(
input_ids,
self.get_language_model().embed_input_ids,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)
if len(multimodal_embeddings) == 0:
return inputs_embeds
# Check for audio-in-video: interleaved video and audio tokens
# in the multimodal region.
# in the multimodal region. Only use the interleaved path when
# needed; otherwise fall back to the default parent implementation.
video_token_id = self.config.video_token_index
audio_token_id = self.config.audio_token_index
@@ -1403,6 +1392,12 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
num_audio = is_audio.sum().item()
if check_interleaved_audio_video(is_video, is_audio, num_video, num_audio):
inputs_embeds = self._embed_text_input_ids(
input_ids,
self.get_language_model().embed_input_ids,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)
return merge_interleaved_embeddings(
inputs_embeds,
multimodal_embeddings,
@@ -1413,9 +1408,12 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
num_audio,
)
# Default: standard merge (no interleaving)
return _merge_multimodal_embeddings(
inputs_embeds, multimodal_embeddings, is_multimodal
# Default: standard merge (no interleaving), same as parent class
return super().embed_input_ids(
input_ids,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)
def forward(

View File

@@ -1904,15 +1904,17 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
num_audio,
)
# Default: standard merge (no interleaving)
inputs_embeds = _merge_multimodal_embeddings(
inputs_embeds=inputs_embeds,
# Default: standard merge (no interleaving), same as parent class.
# multimodal_embeddings may have been updated above (deepstack
# main-scale). Use super() to stay consistent with the parent
# implementation and avoid issues seen in Qwen2.5-Omni (#34506).
return super().embed_input_ids(
input_ids,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)
return inputs_embeds
def forward(
self,
input_ids: torch.Tensor | None,