diff --git a/examples/offline_inference/qwen2_5_omni/only_thinker.py b/examples/offline_inference/qwen2_5_omni/only_thinker.py index cee83519f..476ca91d9 100644 --- a/examples/offline_inference/qwen2_5_omni/only_thinker.py +++ b/examples/offline_inference/qwen2_5_omni/only_thinker.py @@ -112,10 +112,36 @@ def get_multi_audios_query() -> QueryResult: ) +def get_multi_images_query() -> QueryResult: + question = "What are the differences between these two images?" + prompt = ( + f"<|im_start|>system\n{default_system}<|im_end|>\n" + "<|im_start|>user\n<|vision_bos|><|IMAGE|><|vision_eos|>" + "<|vision_bos|><|IMAGE|><|vision_eos|>" + f"{question}<|im_end|>\n" + f"<|im_start|>assistant\n" + ) + return QueryResult( + inputs={ + "prompt": prompt, + "multi_modal_data": { + "image": [ + convert_image_mode(ImageAsset("cherry_blossom").pil_image, "RGB"), + convert_image_mode(ImageAsset("stop_sign").pil_image, "RGB"), + ], + }, + }, + limit_mm_per_prompt={ + "image": 2, + }, + ) + + query_map = { "mixed_modalities": get_mixed_modalities_query, "use_audio_in_video": get_use_audio_in_video_query, "multi_audios": get_multi_audios_query, + "multi_images": get_multi_images_query, } diff --git a/vllm/model_executor/models/qwen2_5_omni_thinker.py b/vllm/model_executor/models/qwen2_5_omni_thinker.py index bc4a0ecdd..76afd7749 100644 --- a/vllm/model_executor/models/qwen2_5_omni_thinker.py +++ b/vllm/model_executor/models/qwen2_5_omni_thinker.py @@ -22,10 +22,11 @@ # limitations under the License. """Inference-only Qwen2.5-Omni model (thinker part).""" -from collections.abc import Callable, Iterable, Mapping, Sequence +from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence from functools import partial from typing import Annotated, Any, Literal +import numpy as np import torch import torch.nn as nn from transformers import PretrainedConfig @@ -85,6 +86,7 @@ from vllm.multimodal.processing.processor import ( PlaceholderFeaturesInfo, PromptReplacement, PromptUpdate, + PromptUpdateDetails, ) from vllm.sequence import IntermediateTensors from vllm.utils.tensor_schema import TensorSchema, TensorShape @@ -103,7 +105,6 @@ from .utils import ( maybe_prefix, split_list_into_ranges, ) -from .vision import get_llm_pos_ids_for_vision try: import flash_attn @@ -374,6 +375,67 @@ class Qwen2_5OmniThinkerMultiModalProcessor( self.info.get_hf_config().vision_config.spatial_merge_size )(hf_inputs) + def _derive_audio_from_video_placeholders( + self, + placeholders: Mapping[str, list[PlaceholderFeaturesInfo]], + mm_prompt_updates: MultiModalPromptUpdates, + ) -> Mapping[str, list[PlaceholderFeaturesInfo]]: + """ + Helper to derive audio placeholders from video placeholders when + use_audio_in_video=True. + """ + if "video" not in placeholders: + return placeholders + + # Validate audio and video counts match + num_videos = len(placeholders["video"]) + num_audios = len(mm_prompt_updates.get("audio", [])) + if num_audios != num_videos: + raise ValueError( + f"use_audio_in_video requires equal number of audio and video " + f"items, got {num_audios=}, {num_videos=}" + ) + + tokenizer = self.info.get_tokenizer() + processor = self.info.get_hf_processor() + audio_token_id = tokenizer.get_vocab()[processor.audio_token] + video_token_id = tokenizer.get_vocab()[processor.video_token] + + result_placeholders = dict(placeholders) + audio_placeholders = [] + video_placeholders = [] + + # Each video is paired with one audio + for video_idx, video_placeholder in enumerate(placeholders["video"]): + # Create is_embed mask selecting only audio tokens + audio_is_embed = torch.tensor(video_placeholder.tokens) == audio_token_id + + # Create is_embed mask selecting only video tokens + video_is_embed = torch.tensor(video_placeholder.tokens) == video_token_id + + audio_placeholder = PlaceholderFeaturesInfo( + modality="audio", + item_idx=video_idx, + start_idx=video_placeholder.start_idx, + tokens=video_placeholder.tokens, + is_embed=audio_is_embed, + ) + audio_placeholders.append(audio_placeholder) + + # Update video placeholder with is_embed mask + video_placeholder_with_mask = PlaceholderFeaturesInfo( + modality="video", + item_idx=video_idx, + start_idx=video_placeholder.start_idx, + tokens=video_placeholder.tokens, + is_embed=video_is_embed, + ) + video_placeholders.append(video_placeholder_with_mask) + + result_placeholders["audio"] = audio_placeholders + result_placeholders["video"] = video_placeholders + return result_placeholders + def _maybe_apply_prompt_updates( self, mm_items: MultiModalDataItems, @@ -389,6 +451,16 @@ class Qwen2_5OmniThinkerMultiModalProcessor( self._validate_mm_kwargs(mm_kwargs, mm_item_counts) self._validate_mm_updates(mm_prompt_updates, mm_item_counts) + # Detect use_audio_in_video from mm_kwargs + use_audio_in_video = False + if "video" in mm_kwargs: + for item in mm_kwargs["video"]: + if item and item.get("use_audio_in_video"): + use_audio_in_video_tensor = item["use_audio_in_video"].data + if use_audio_in_video_tensor.numel() > 0: + use_audio_in_video = bool(use_audio_in_video_tensor.item()) + break + if is_update_applied: mm_placeholders = self._find_mm_placeholders( prompt_ids, @@ -399,10 +471,25 @@ class Qwen2_5OmniThinkerMultiModalProcessor( mm_item_counts, ) else: - prompt_ids, mm_placeholders = self._apply_prompt_updates( - prompt_ids, - mm_prompt_updates, - ) + if use_audio_in_video and "audio" in mm_prompt_updates: + # Filter out audio updates - they are embedded in video + filtered_updates = { + k: v for k, v in mm_prompt_updates.items() if k != "audio" + } + prompt_ids, mm_placeholders = self._apply_prompt_updates( + prompt_ids, + filtered_updates, + ) + # Derive audio placeholders from video placeholders + mm_placeholders = self._derive_audio_from_video_placeholders( + mm_placeholders, mm_prompt_updates + ) + else: + prompt_ids, mm_placeholders = self._apply_prompt_updates( + prompt_ids, + mm_prompt_updates, + ) + self._validate_mm_placeholders( mm_placeholders, mm_item_counts, @@ -542,13 +629,19 @@ class Qwen2_5OmniThinkerMultiModalProcessor( else: video_second_per_grid_t = 1.0 - return self.omni_get_updates_use_audio_in_video( + updates = self.omni_get_updates_use_audio_in_video( thinker_config=thinker_config, audio_len=audio_num_features, video_grid_thw=video_grid_thw, video_second_per_grid_t=video_second_per_grid_t, ) + # Only video tokens should receive video embeddings + return PromptUpdateDetails.select_token_id( + seq=updates, + embed_token_id=video_token_id, + ) + video_replacement_fn = ( get_replacement_qwen2_use_audio_in_video if use_audio_in_video @@ -889,216 +982,276 @@ class Qwen2_5OmniThinkerForConditionalGeneration( ) return mm_input_by_modality + def _get_audio_for_video_mapping( + self, mm_features: list[MultiModalFeatureSpec] + ) -> tuple[dict[int, int], set[int]]: + """ + Map video offset -> paired audio_feature_length for use_audio_in_video. + + When use_audio_in_video=True, audio is interleaved within video chunks. + The pairing is based on feature order in mm_features. + + Returns: + Tuple of (video_offset -> audio_feature_length mapping, + set of paired audio offsets to skip) + """ + videos_with_audio = [ + f + for f in mm_features + if f.modality == "video" + and f.data.get("use_audio_in_video") + and f.data["use_audio_in_video"].data.item() + ] + audios = [f for f in mm_features if f.modality == "audio"] + + # Pair videos with audio features (assumes matching order) + mapping: dict[int, int] = {} + paired_audio_offsets: set[int] = set() + for i, video_f in enumerate(videos_with_audio): + if i < len(audios): + audio_len = audios[i].data["audio_feature_lengths"].data.item() + mapping[video_f.mm_position.offset] = audio_len + paired_audio_offsets.add(audios[i].mm_position.offset) + return mapping, paired_audio_offsets + + def _compute_audio_token_count(self, audio_feature_length: int) -> int: + """Compute audio tokens from feature length.""" + return ((audio_feature_length - 1) // 2 + 1 - 2) // 2 + 1 + + def iter_mm_features( + self, mm_features: list[MultiModalFeatureSpec] + ) -> Iterator[tuple[int, str, dict[str, Any]]]: + """ + Iterate over multimodal features sorted by position offset. + + Yields: (offset, modality, feature_data) where feature_data contains: + - image: {"grid_t", "grid_h", "grid_w", "t_factor"} + - video: {"grid_t", "grid_h", "grid_w", "t_factor", + "use_audio_in_video", "audio_feature_length"} + - audio: {"audio_feature_length"} + """ + thinker_config = self.config + spatial_merge_size = thinker_config.vision_config.spatial_merge_size + tokens_per_second = getattr( + thinker_config.vision_config, "tokens_per_second", 25 + ) + + # Sort features by offset first, then pair audio with video + sorted_features = sorted(mm_features, key=lambda f: f.mm_position.offset) + audio_for_video, paired_audio_offsets = self._get_audio_for_video_mapping( + sorted_features + ) + + for mm_feature in sorted_features: + offset = mm_feature.mm_position.offset + modality = mm_feature.modality + + if modality == "image": + t, h, w = mm_feature.data["image_grid_thw"].data.tolist() + yield ( + offset, + "image", + { + "grid_t": t, + "grid_h": h // spatial_merge_size, + "grid_w": w // spatial_merge_size, + "t_factor": 1.0 * tokens_per_second, + }, + ) + elif modality == "video": + t, h, w = mm_feature.data["video_grid_thw"].data.tolist() + second_per_grid_ts = 1.0 + if mm_feature.data.get("second_per_grid_ts"): + second_per_grid_ts = mm_feature.data[ + "second_per_grid_ts" + ].data.item() + use_audio_in_video = False + if mm_feature.data.get("use_audio_in_video"): + use_audio_in_video = bool( + mm_feature.data["use_audio_in_video"].data.item() + ) + + yield ( + offset, + "video", + { + "grid_t": t, + "grid_h": h // spatial_merge_size, + "grid_w": w // spatial_merge_size, + "t_factor": second_per_grid_ts * tokens_per_second, + "use_audio_in_video": use_audio_in_video, + "audio_feature_length": audio_for_video.get(offset), + }, + ) + elif modality == "audio": + # Skip audio that's paired with video (handled in video case) + if offset not in paired_audio_offsets: + audio_len = mm_feature.data["audio_feature_lengths"].data.item() + yield offset, "audio", {"audio_feature_length": audio_len} + + def _compute_interleaved_positions( + self, start_idx: int, data: dict[str, Any] + ) -> tuple[np.ndarray, int]: + """ + Compute positions for interleaved video+audio chunks. + + Returns: (position_ids, total_token_count) + """ + grid_t = data["grid_t"] + grid_h = data["grid_h"] + grid_w = data["grid_w"] + t_factor = data["t_factor"] + audio_len = data["audio_feature_length"] + + thinker_config = self.config + tokens_per_second = getattr( + thinker_config.vision_config, "tokens_per_second", 25 + ) + seconds_per_chunk = thinker_config.seconds_per_chunk + t_ntoken_per_chunk = int(tokens_per_second * seconds_per_chunk) + + # Temporal indices with scaling + t_index = (np.arange(grid_t) * t_factor).astype(np.int64) + + # Split temporal indices into chunks + t_index_split_chunk: list[list[int]] = [ + [] for _ in range((int(t_index.max()) // t_ntoken_per_chunk) + 1) + ] + for t_val in t_index: + idx = int(t_val) // t_ntoken_per_chunk + t_index_split_chunk[idx].append(int(t_val)) + + pure_audio_len = self._compute_audio_token_count(audio_len) + added_audio_len = 0 + pos_ids_list: list[np.ndarray] = [] + audio_start_idx = start_idx + + for t_chunk in t_index_split_chunk: + if not t_chunk: + continue + + chunk_t = len(t_chunk) + + # Build vision positions for this chunk + h_indices = np.tile( + np.arange(grid_h).reshape(1, -1, 1), (chunk_t, 1, grid_w) + ).flatten() + w_indices = np.tile( + np.arange(grid_w).reshape(1, 1, -1), (chunk_t, grid_h, 1) + ).flatten() + t_indices = np.repeat(np.array(t_chunk), grid_h * grid_w) + + vision_pos = np.stack([t_indices, h_indices, w_indices]) + start_idx + pos_ids_list.append(vision_pos) + + # Audio tokens for this chunk + audio_chunk_size = min(t_ntoken_per_chunk, pure_audio_len - added_audio_len) + if audio_chunk_size > 0: + audio_pos = ( + np.broadcast_to(np.arange(audio_chunk_size), (3, audio_chunk_size)) + + audio_start_idx + ) + pos_ids_list.append(audio_pos) + audio_start_idx = audio_start_idx + audio_chunk_size + added_audio_len += audio_chunk_size + + # Handle remaining audio that doesn't fit in chunks + if added_audio_len < pure_audio_len: + remaining = pure_audio_len - added_audio_len + remaining_audio_pos = ( + np.broadcast_to(np.arange(remaining), (3, remaining)) + audio_start_idx + ) + pos_ids_list.append(remaining_audio_pos) + + # Calculate total token count + vision_tokens = grid_t * grid_h * grid_w + total_tokens = vision_tokens + pure_audio_len + + return np.concatenate(pos_ids_list, axis=1), total_tokens + def get_mrope_input_positions( self, input_tokens: list[int], mm_features: list[MultiModalFeatureSpec], ) -> tuple[torch.Tensor, int]: """ - Example: + Compute M-RoPE input positions using mm_features directly. + Example for use_audio_in_video case: (V_i are vision position ids, A_i are audio position ids) |V_1 ... V_n|A_1 ... A_n|V_n+1 ... V_2n|A_n+1 ... A_2n|... |vision chunk 1|audio chunk 1|vision chunk 2|audio chunk 2 |... """ - kwargs = MultiModalFeatureSpec.gather_kwargs( - mm_features, - { - "image_grid_thw", - "video_grid_thw", - "second_per_grid_ts", - "audio_feature_lengths", - "use_audio_in_video", - }, - ) - image_grid_thw = kwargs.get("image_grid_thw", []) - video_grid_thw = kwargs.get("video_grid_thw", []) - second_per_grid_ts = kwargs.get("second_per_grid_ts", []) - audio_feature_lengths = kwargs.get("audio_feature_lengths", []) - use_audio_in_video = any(kwargs.get("use_audio_in_video", [])) + llm_pos_ids_list: list[np.ndarray] = [] + st = 0 - image_grid_thw = (torch.stack if image_grid_thw else torch.tensor)( - image_grid_thw - ) - video_grid_thw = (torch.stack if video_grid_thw else torch.tensor)( - video_grid_thw - ) + for offset, modality, data in self.iter_mm_features(mm_features): + # Add text segment before this feature + text_len = offset - st + st_idx = int(llm_pos_ids_list[-1].max()) + 1 if llm_pos_ids_list else 0 + if text_len > 0: + llm_pos_ids_list.append( + np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx + ) + st_idx += text_len - # TODO(fyabc): refactor and share more code with - # _vl_get_input_positions_tensor. + if modality == "audio": + # Standalone audio positions + audio_tokens = self._compute_audio_token_count( + data["audio_feature_length"] + ) + llm_pos_ids_list.append( + np.broadcast_to(np.arange(audio_tokens), (3, audio_tokens)) + st_idx + ) + st = offset + audio_tokens - thinker_config = self.config - audio_token_id = thinker_config.audio_token_index - image_token_id = thinker_config.image_token_index - video_token_id = thinker_config.video_token_index - audio_start_token_id = thinker_config.audio_start_token_id - audio_end_token_id = thinker_config.audio_end_token_id - vision_start_token_id = thinker_config.vision_start_token_id - vision_end_token_id = thinker_config.vision_end_token_id - seconds_per_chunk = thinker_config.seconds_per_chunk - spatial_merge_size = thinker_config.vision_config.spatial_merge_size - tokens_per_second = getattr( - thinker_config.vision_config, "tokens_per_second", 25 - ) + elif modality == "image": + # Image uses np.indices like Qwen2-VL + grid_t = data["grid_t"] + grid_h = data["grid_h"] + grid_w = data["grid_w"] + t_factor = data["t_factor"] - src_item = input_tokens - audio_seqlens = audio_feature_lengths - if not second_per_grid_ts: - second_per_grid_ts = [1] * video_grid_thw.shape[0] - audio_idx = 0 - video_idx = 0 - image_idx = 0 - new_src_item: list[int] = [] - llm_pos_ids_list: list[torch.Tensor] = [] + grid_indices = np.indices((grid_t, grid_h, grid_w)) + if t_factor != 1.0: + grid_indices[0] = (grid_indices[0] * t_factor).astype(np.int64) + llm_pos_ids_list.append(grid_indices.reshape(3, -1) + st_idx) + st = offset + grid_t * grid_h * grid_w - idx = 0 - while idx < len(src_item): - new_src_item_len = len(new_src_item) - start_idx = ( - llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + elif modality == "video": + grid_t = data["grid_t"] + grid_h = data["grid_h"] + grid_w = data["grid_w"] + t_factor = data["t_factor"] + + if not data["use_audio_in_video"]: + # Simple video (same as Qwen2-VL) + grid_indices = np.indices((grid_t, grid_h, grid_w)) + if t_factor != 1.0: + grid_indices[0] = (grid_indices[0] * t_factor).astype(np.int64) + llm_pos_ids_list.append(grid_indices.reshape(3, -1) + st_idx) + st = offset + grid_t * grid_h * grid_w + else: + # Interleaved video+audio + pos_ids, token_count = self._compute_interleaved_positions( + st_idx, data + ) + llm_pos_ids_list.append(pos_ids) + st = offset + token_count + + # Add trailing text + if st < len(input_tokens): + st_idx = int(llm_pos_ids_list[-1].max()) + 1 if llm_pos_ids_list else 0 + text_len = len(input_tokens) - st + llm_pos_ids_list.append( + np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx ) - if src_item[idx] not in [audio_token_id, video_token_id, image_token_id]: - if use_audio_in_video and idx > 0: - if ( - src_item[idx] == vision_end_token_id - and src_item[idx - 1] == audio_end_token_id - ): - # processing the <|audio_eos|> before <|vision_eos|> - start_idx -= 1 - elif ( - src_item[idx] == audio_start_token_id - and src_item[idx - 1] == vision_start_token_id - ): - # processing the <|audio_bos|> after <|vision_eos|> - start_idx -= 1 - new_src_item.append(src_item[idx]) - llm_pos_ids = torch.tensor([start_idx], dtype=torch.long).expand(3, -1) - llm_pos_ids_list.append(llm_pos_ids) - elif src_item[idx] == audio_token_id: - assert audio_seqlens is not None - audio_seqlen = audio_seqlens[audio_idx] - place_num = ((audio_seqlen - 1) // 2 + 1 - 2) // 2 + 1 - new_src_item.extend([audio_token_id] * place_num) - llm_pos_ids = torch.arange(place_num).expand(3, -1) + start_idx - llm_pos_ids_list.append(llm_pos_ids) - audio_idx += 1 - elif src_item[idx] == image_token_id: - grid_t = image_grid_thw[image_idx][0] - grid_hs = image_grid_thw[:, 1] - grid_ws = image_grid_thw[:, 2] - t_index = (torch.arange(grid_t) * 1 * tokens_per_second).long() - llm_pos_ids = get_llm_pos_ids_for_vision( - start_idx, image_idx, spatial_merge_size, t_index, grid_hs, grid_ws - ) - llm_pos_ids_list.append(llm_pos_ids) - vision_seqlen = image_grid_thw[image_idx].prod() // ( - spatial_merge_size**2 - ) - new_src_item.extend([image_token_id] * vision_seqlen) - image_idx += 1 - elif src_item[idx] == video_token_id and not use_audio_in_video: - grid_t = video_grid_thw[video_idx][0] - grid_hs = video_grid_thw[:, 1] - grid_ws = video_grid_thw[:, 2] - t_index = ( - torch.arange(grid_t) - * second_per_grid_ts[video_idx] - * tokens_per_second - ).long() - llm_pos_ids = get_llm_pos_ids_for_vision( - start_idx, video_idx, spatial_merge_size, t_index, grid_hs, grid_ws - ) - llm_pos_ids_list.append(llm_pos_ids) - vision_seqlen = video_grid_thw[video_idx].prod() // ( - spatial_merge_size**2 - ) - new_src_item.extend([video_token_id] * vision_seqlen) - video_idx += 1 - else: - # read audio from video - assert audio_seqlens is not None - audio_seqlen = audio_seqlens[audio_idx] - vision_seqlen = video_grid_thw[video_idx].prod() // ( - spatial_merge_size**2 - ) - grid_t = video_grid_thw[video_idx][0] - grid_h = video_grid_thw[video_idx][1] - grid_w = video_grid_thw[video_idx][2] - grid_hs = video_grid_thw[:, 1] - grid_ws = video_grid_thw[:, 2] - t_ntoken_per_chunk = int(tokens_per_second * seconds_per_chunk) - t_index = ( - torch.arange(grid_t) - * second_per_grid_ts[video_idx] - * tokens_per_second - ).long() - t_index_split_chunk = split_list_into_ranges( - t_index, t_ntoken_per_chunk - ) - place_num = (((audio_seqlen - 1) // 2 + 1 - 2) // 2 + 1) + 2 - pure_audio_len = place_num - 2 - added_audio_len = 0 - audio_llm_pos_ids_list: list[torch.Tensor] = [] - for t_chunk in t_index_split_chunk: - vision_ntoken_per_chunk = ( - len(t_chunk) * grid_h * grid_w // (spatial_merge_size**2) - ) - new_src_item.extend([video_token_id] * vision_ntoken_per_chunk) - vision_llm_pos_ids_list = get_llm_pos_ids_for_vision( - start_idx, - video_idx, - spatial_merge_size, - t_chunk, - grid_hs, - grid_ws, - ).split(1, dim=1) - llm_pos_ids_list.extend(vision_llm_pos_ids_list) - new_src_item.extend( - min(t_ntoken_per_chunk, pure_audio_len - added_audio_len) - * [audio_token_id] - ) - audio_start_idx = ( - start_idx - if len(audio_llm_pos_ids_list) == 0 - else audio_llm_pos_ids_list[-1][0].item() + 1 - ) - if min(t_ntoken_per_chunk, pure_audio_len - added_audio_len) > 0: - audio_llm_pos_ids_list = ( - torch.arange( - min( - t_ntoken_per_chunk, pure_audio_len - added_audio_len - ) - ).expand(3, -1) - + audio_start_idx - ).split(1, dim=1) - else: - audio_llm_pos_ids_list = [] - added_audio_len += min( - t_ntoken_per_chunk, pure_audio_len - added_audio_len - ) - llm_pos_ids_list.extend(audio_llm_pos_ids_list) - if added_audio_len < pure_audio_len: - new_src_item.extend( - (pure_audio_len - added_audio_len) * [audio_token_id] - ) - audio_llm_pos_ids_list = ( - torch.arange(pure_audio_len - added_audio_len).expand(3, -1) - + llm_pos_ids_list[-1].max() - + 1 - ).split(1, dim=1) - llm_pos_ids_list.extend(audio_llm_pos_ids_list) - audio_idx += 1 - video_idx += 1 - # move to the next token - idx += len(new_src_item) - new_src_item_len - llm_positions = torch.cat(llm_pos_ids_list, dim=1) - mrope_position_delta = ( - torch.cat(llm_pos_ids_list, dim=1).max() + 1 - len(src_item) - ) + llm_positions = np.concatenate(llm_pos_ids_list, axis=1).reshape(3, -1) + mrope_position_delta = int(llm_positions.max()) + 1 - len(input_tokens) - return llm_positions, mrope_position_delta + return torch.from_numpy(llm_positions), mrope_position_delta def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings: mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index cdf687a7e..64f6263cc 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2474,9 +2474,15 @@ class GPUModelRunner( mm_embeds_item = encoder_output[start_idx:end_idx] req_start_pos = req_start_idx + start_pos - num_computed_tokens - is_mm_embed[req_start_pos + start_idx : req_start_pos + end_idx] = ( - True if is_embed is None else is_embed - ) + # OR mask for overlapping mm_features (use_audio_in_video) + if is_embed is None: + is_mm_embed[req_start_pos + start_idx : req_start_pos + end_idx] = ( + True + ) + else: + is_mm_embed[ + req_start_pos + start_idx : req_start_pos + end_idx + ] |= is_embed mm_embeds_req.append(mm_embeds_item) if self.is_multimodal_pruning_enabled and self.uses_mrope: