[Bugfix] Fix Qwen2.5-Omni and Qwen3-Omni mixed-modality embed regression (#35368)
Signed-off-by: linyueqian <linyueqian@outlook.com>
This commit is contained in:
@@ -1376,23 +1376,12 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
|
||||
is_multimodal: torch.Tensor | None = None,
|
||||
handle_oov_mm_token: bool = False,
|
||||
) -> torch.Tensor:
|
||||
from .utils import _merge_multimodal_embeddings
|
||||
|
||||
if multimodal_embeddings is None or is_multimodal is None:
|
||||
return super().embed_input_ids(input_ids)
|
||||
|
||||
inputs_embeds = self._embed_text_input_ids(
|
||||
input_ids,
|
||||
self.get_language_model().embed_input_ids,
|
||||
is_multimodal=is_multimodal,
|
||||
handle_oov_mm_token=handle_oov_mm_token,
|
||||
)
|
||||
|
||||
if len(multimodal_embeddings) == 0:
|
||||
return inputs_embeds
|
||||
|
||||
# Check for audio-in-video: interleaved video and audio tokens
|
||||
# in the multimodal region.
|
||||
# in the multimodal region. Only use the interleaved path when
|
||||
# needed; otherwise fall back to the default parent implementation.
|
||||
video_token_id = self.config.video_token_index
|
||||
audio_token_id = self.config.audio_token_index
|
||||
|
||||
@@ -1403,6 +1392,12 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
|
||||
num_audio = is_audio.sum().item()
|
||||
|
||||
if check_interleaved_audio_video(is_video, is_audio, num_video, num_audio):
|
||||
inputs_embeds = self._embed_text_input_ids(
|
||||
input_ids,
|
||||
self.get_language_model().embed_input_ids,
|
||||
is_multimodal=is_multimodal,
|
||||
handle_oov_mm_token=handle_oov_mm_token,
|
||||
)
|
||||
return merge_interleaved_embeddings(
|
||||
inputs_embeds,
|
||||
multimodal_embeddings,
|
||||
@@ -1413,9 +1408,12 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
|
||||
num_audio,
|
||||
)
|
||||
|
||||
# Default: standard merge (no interleaving)
|
||||
return _merge_multimodal_embeddings(
|
||||
inputs_embeds, multimodal_embeddings, is_multimodal
|
||||
# Default: standard merge (no interleaving), same as parent class
|
||||
return super().embed_input_ids(
|
||||
input_ids,
|
||||
multimodal_embeddings=multimodal_embeddings,
|
||||
is_multimodal=is_multimodal,
|
||||
handle_oov_mm_token=handle_oov_mm_token,
|
||||
)
|
||||
|
||||
def forward(
|
||||
|
||||
@@ -1904,15 +1904,17 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
|
||||
num_audio,
|
||||
)
|
||||
|
||||
# Default: standard merge (no interleaving)
|
||||
inputs_embeds = _merge_multimodal_embeddings(
|
||||
inputs_embeds=inputs_embeds,
|
||||
# Default: standard merge (no interleaving), same as parent class.
|
||||
# multimodal_embeddings may have been updated above (deepstack
|
||||
# main-scale). Use super() to stay consistent with the parent
|
||||
# implementation and avoid issues seen in Qwen2.5-Omni (#34506).
|
||||
return super().embed_input_ids(
|
||||
input_ids,
|
||||
multimodal_embeddings=multimodal_embeddings,
|
||||
is_multimodal=is_multimodal,
|
||||
handle_oov_mm_token=handle_oov_mm_token,
|
||||
)
|
||||
|
||||
return inputs_embeds
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor | None,
|
||||
|
||||
Reference in New Issue
Block a user