diff --git a/vllm/model_executor/models/qwen2_5_omni_thinker.py b/vllm/model_executor/models/qwen2_5_omni_thinker.py
index 5152a73de..3b50ae74d 100644
--- a/vllm/model_executor/models/qwen2_5_omni_thinker.py
+++ b/vllm/model_executor/models/qwen2_5_omni_thinker.py
@@ -113,6 +113,95 @@ except (ImportError, ModuleNotFoundError):
 logger = init_logger(__name__)
 
 
+def check_interleaved_audio_video(
+    is_video: torch.Tensor,
+    is_audio: torch.Tensor,
+    num_video: int,
+    num_audio: int,
+) -> bool:
+    """
+    Check if video and audio positions are interleaved in the multimodal region.
+
+    Returns:
+        True if video and audio tokens are interleaved, False otherwise.
+    """
+    if num_video == 0 or num_audio == 0:
+        return False
+
+    video_pos = is_video.nonzero(as_tuple=True)[0]
+    audio_pos = is_audio.nonzero(as_tuple=True)[0]
+
+    return (
+        video_pos[0].item() < audio_pos[-1].item()
+        and audio_pos[0].item() < video_pos[-1].item()
+    )
+
+
+def merge_interleaved_embeddings(
+    inputs_embeds: torch.Tensor,
+    multimodal_embeddings: "MultiModalEmbeddings",
+    is_video: torch.Tensor,
+    is_audio: torch.Tensor,
+    is_multimodal: torch.Tensor,
+    num_video: int,
+    num_audio: int,
+) -> torch.Tensor:
+    """
+    Merge embeddings for interleaved audio-in-video sequences.
+
+    When use_audio_in_video=True, video and audio tokens are interleaved in
+    the token sequence, but embeddings are provided as separate contiguous
+    tensors (video first, then audio). This function reorders video and audio
+    embeddings to match sequence position order and scatters them efficiently.
+
+    Args:
+        inputs_embeds: The input embeddings tensor to merge into.
+        multimodal_embeddings: List of embedding tensors (video, audio, other).
+        is_video: Boolean mask for video token positions.
+        is_audio: Boolean mask for audio token positions.
+        is_multimodal: Boolean mask for all multimodal token positions.
+        num_video: Total count of video tokens.
+        num_audio: Total count of audio tokens.
+
+    Returns:
+        The merged inputs_embeds tensor with multimodal embeddings scattered
+        to their correct positions.
+    """
+    # Categorize embeddings by modality based on token counts.
+    # Embeddings come grouped by modality but order varies (e.g., image, video,
+    # audio or video, audio depending on input kwargs order).
+    video_embeds: list[torch.Tensor] = []
+    audio_embeds: list[torch.Tensor] = []
+    other_embeds: list[torch.Tensor] = []
+    video_remaining = num_video
+    audio_remaining = num_audio
+
+    for emb in multimodal_embeddings:
+        n = emb.shape[0]
+        if video_remaining > 0 and n <= video_remaining:
+            video_embeds.append(emb)
+            video_remaining -= n
+        elif audio_remaining > 0 and n <= audio_remaining:
+            audio_embeds.append(emb)
+            audio_remaining -= n
+        else:
+            other_embeds.append(emb)
+
+    # Scatter each modality to its positions
+    if video_embeds:
+        video_positions = is_video.nonzero(as_tuple=True)[0]
+        inputs_embeds[video_positions] = torch.cat(video_embeds, dim=0)
+    if audio_embeds:
+        audio_positions = is_audio.nonzero(as_tuple=True)[0]
+        inputs_embeds[audio_positions] = torch.cat(audio_embeds, dim=0)
+    if other_embeds:
+        other_mask = is_multimodal & ~is_video & ~is_audio
+        other_positions = other_mask.nonzero(as_tuple=True)[0]
+        inputs_embeds[other_positions] = torch.cat(other_embeds, dim=0)
+
+    return inputs_embeds
+
+
 class Qwen2_5OmniAudioFeatureInputs(TensorSchema):
     """
     Dimensions:
@@ -1286,17 +1375,48 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
         is_multimodal: torch.Tensor | None = None,
         handle_oov_mm_token: bool = False,
     ) -> torch.Tensor:
-        # This is to satisfy the type checker for each overload
+        from .utils import _merge_multimodal_embeddings
+
         if multimodal_embeddings is None or is_multimodal is None:
             return super().embed_input_ids(input_ids)
 
-        return super().embed_input_ids(
+        inputs_embeds = self._embed_text_input_ids(
             input_ids,
-            multimodal_embeddings=multimodal_embeddings,
+            self.get_language_model().embed_input_ids,
             is_multimodal=is_multimodal,
             handle_oov_mm_token=handle_oov_mm_token,
         )
 
+        if len(multimodal_embeddings) == 0:
+            return inputs_embeds
+
+        # Check for audio-in-video: interleaved video and audio tokens
+        # in the multimodal region.
+        video_token_id = self.config.video_token_index
+        audio_token_id = self.config.audio_token_index
+
+        is_video = is_multimodal & (input_ids == video_token_id)
+        is_audio = is_multimodal & (input_ids == audio_token_id)
+
+        num_video = is_video.sum().item()
+        num_audio = is_audio.sum().item()
+
+        if check_interleaved_audio_video(is_video, is_audio, num_video, num_audio):
+            return merge_interleaved_embeddings(
+                inputs_embeds,
+                multimodal_embeddings,
+                is_video,
+                is_audio,
+                is_multimodal,
+                num_video,
+                num_audio,
+            )
+
+        # Default: standard merge (no interleaving)
+        return _merge_multimodal_embeddings(
+            inputs_embeds, multimodal_embeddings, is_multimodal
+        )
+
     def forward(
         self,
         input_ids: torch.Tensor | None,
diff --git a/vllm/model_executor/models/qwen3_omni_moe_thinker.py b/vllm/model_executor/models/qwen3_omni_moe_thinker.py
index 9500ce2e2..93a17f0c8 100755
--- a/vllm/model_executor/models/qwen3_omni_moe_thinker.py
+++ b/vllm/model_executor/models/qwen3_omni_moe_thinker.py
@@ -92,6 +92,8 @@ from .qwen2_5_omni_thinker import (
     Qwen2_5OmniConditionalGenerationMixin,
     Qwen2_5OmniThinkerDummyInputsBuilder,
     Qwen2_5OmniThinkerMultiModalProcessor,
+    check_interleaved_audio_video,
+    merge_interleaved_embeddings,
 )
 from .qwen2_5_vl import (
     Qwen2_5_VisionAttention,
@@ -1780,6 +1782,19 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
         if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
             return inputs_embeds
 
+        # Detect interleaved audio-in-video early, since it affects
+        # both the deepstack path and the final embedding merge.
+        video_token_id = self.config.video_token_id
+        audio_token_id = self.config.audio_token_id
+        is_video = is_multimodal & (input_ids == video_token_id)
+        is_audio = is_multimodal & (input_ids == audio_token_id)
+        num_video = is_video.sum().item()
+        num_audio = is_audio.sum().item()
+
+        is_interleaved = check_interleaved_audio_video(
+            is_video, is_audio, num_video, num_audio
+        )
+
         deepstack_input_embeds = None
         # split the feat dim to obtain multi-scale visual feature
         has_vision_embeddings = [
@@ -1791,14 +1806,18 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
         ):
             multiscale_len = len(self.visual.deepstack_visual_indexes)
             multimodal_embeddings_multiscale = []
-            is_vision = torch.zeros_like(is_multimodal)
-            mm_positions = torch.nonzero(is_multimodal, as_tuple=True)[0]
-            mm_position_idx = 0
+
+            if is_interleaved:
+                # Use input_ids-based mask for correct vision positions
+                # when audio and video tokens are interleaved.
+                is_vision = is_video.clone()
+            else:
+                is_vision = torch.zeros_like(is_multimodal)
+                mm_positions = torch.nonzero(is_multimodal, as_tuple=True)[0]
+                mm_position_idx = 0
+
             for index, embeddings in enumerate(multimodal_embeddings):
                 num_tokens = embeddings.shape[0]
-                current_positions = mm_positions[
-                    mm_position_idx : mm_position_idx + num_tokens
-                ]
 
                 # Vision embeddings
                 if embeddings.shape[-1] != self.config.text_config.hidden_size:
@@ -1809,13 +1828,22 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
                     )
                     multimodal_embeddings[index] = embeddings_main
                     multimodal_embeddings_multiscale.append(embeddings_multiscale)
-                    is_vision[current_positions] = True
+                    if not is_interleaved:
+                        current_positions = mm_positions[
+                            mm_position_idx : mm_position_idx + num_tokens
+                        ]
+                        is_vision[current_positions] = True
 
                 # Audio embeddings
                 else:
-                    is_vision[current_positions] = False
+                    if not is_interleaved:
+                        current_positions = mm_positions[
+                            mm_position_idx : mm_position_idx + num_tokens
+                        ]
+                        is_vision[current_positions] = False
 
-                mm_position_idx += num_tokens
+                if not is_interleaved:
+                    mm_position_idx += num_tokens
 
             deepstack_input_embeds = inputs_embeds.new_zeros(
                 inputs_embeds.size(0), multiscale_len * inputs_embeds.size(1)
@@ -1834,6 +1862,18 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
             )
             self._set_deepstack_input_embeds(deepstack_input_embeds)
 
+        if is_interleaved:
+            return merge_interleaved_embeddings(
+                inputs_embeds,
+                multimodal_embeddings,
+                is_video,
+                is_audio,
+                is_multimodal,
+                num_video,
+                num_audio,
+            )
+
+        # Default: standard merge (no interleaving)
         inputs_embeds = _merge_multimodal_embeddings(
             inputs_embeds=inputs_embeds,
             multimodal_embeddings=multimodal_embeddings,
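
For reviewers, here is a rough, self-contained sketch of what the interleaved scatter does on toy tensors. It imports nothing from vLLM; the token ids, hidden size, and embedding values are made up for illustration, and only the core position-ordered scatter from `merge_interleaved_embeddings` is reproduced:

```python
import torch

# Toy sequence: text(0), video(1), audio(2), video(1), audio(2), text(0).
# With use_audio_in_video, video and audio placeholder tokens interleave like this.
VIDEO_ID, AUDIO_ID = 1, 2  # made-up ids for this sketch
input_ids = torch.tensor([0, 1, 2, 1, 2, 0])
hidden_size = 4

inputs_embeds = torch.zeros(len(input_ids), hidden_size)
is_multimodal = input_ids != 0
is_video = is_multimodal & (input_ids == VIDEO_ID)
is_audio = is_multimodal & (input_ids == AUDIO_ID)

# Embeddings arrive grouped by modality: all video rows first, then all audio rows.
video_embeds = torch.arange(2 * hidden_size, dtype=torch.float32).reshape(2, hidden_size)
audio_embeds = 100 + torch.arange(2 * hidden_size, dtype=torch.float32).reshape(2, hidden_size)

# Core of merge_interleaved_embeddings: scatter each modality to its own token
# positions, so the contiguous per-modality blocks land interleaved in sequence
# order without any explicit sort.
inputs_embeds[is_video.nonzero(as_tuple=True)[0]] = video_embeds
inputs_embeds[is_audio.nonzero(as_tuple=True)[0]] = audio_embeds

# Positions 1..4 now hold video[0], audio[0], video[1], audio[1].
print(inputs_embeds[:, 0])  # tensor([  0.,   0., 100.,   4., 104.,   0.])
```

Scattering by per-modality boolean masks is what lets the contiguous video-then-audio embedding tensors end up at interleaved sequence positions, which is why the standard contiguous `_merge_multimodal_embeddings` path is bypassed in that case.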