diff --git a/tests/model_executor/test_qwen3_vl_mrope.py b/tests/model_executor/test_qwen3_vl_mrope.py new file mode 100644 index 000000000..90d9fd6e4 --- /dev/null +++ b/tests/model_executor/test_qwen3_vl_mrope.py @@ -0,0 +1,237 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import dataclasses +import random +from dataclasses import dataclass + +import pytest +import torch + +from vllm.model_executor.models.qwen3_vl import Qwen3VLForConditionalGeneration +from vllm.multimodal.inputs import ( + MultiModalFeatureSpec, + MultiModalFieldElem, + MultiModalKwargsItem, + PlaceholderRange, +) + + +@pytest.fixture(autouse=True, scope="module") +def _force_cpu_default_device(): + # _get_mrope_input_positions returns CPU tensors (via torch.from_numpy). + # Ensure the default device is CPU so the rest of the test tensors match. + original = torch.get_default_device() + torch.set_default_device("cpu") + yield + torch.set_default_device(original) + + +IMAGE_TOKEN_ID = 999 +VIDEO_TOKEN_ID = 888 +VISION_START_TOKEN_ID = 777 +VISION_END_TOKEN_ID = 778 + + +@dataclass +class DummyVisionConfig: + spatial_merge_size: int = 1 + + +@dataclass +class DummyConfig: + image_token_id: int = IMAGE_TOKEN_ID + video_token_id: int = VIDEO_TOKEN_ID + vision_start_token_id: int = VISION_START_TOKEN_ID + vision_end_token_id: int = VISION_END_TOKEN_ID + vision_config: DummyVisionConfig = dataclasses.field( + default_factory=DummyVisionConfig + ) + + +def make_video_embedding( + t, h, w, interleave_text_tokens: tuple[int, int], video_pruning_rate: float = 0.0 +): + """ + Helper function to make a video embedding for a given video size and pruning rate. + + Args: + t: Number of frames. + h: Number of rows. + w: Number of columns. + interleave_text_tokens: Tuple of minimum and maximum number of text tokens to + interleave with the video. + video_pruning_rate: Pruning rate for the video. + + Returns: + Tuple of (unpruned_tokens_sequence, pruned_tokens_sequence, retention_mask) + """ + unpruned_tokens_sequence = [] + population = list(range(1, 100)) + + for _ in range(t): + num_prefix_tokens = random.randint( + interleave_text_tokens[0], interleave_text_tokens[1] + ) + + prefix_tokens = random.choices(population, k=num_prefix_tokens) + vision_tokens = ( + [VISION_START_TOKEN_ID] + [VIDEO_TOKEN_ID] * h * w + [VISION_END_TOKEN_ID] + ) + + unpruned_tokens_sequence.extend(prefix_tokens) + unpruned_tokens_sequence.extend(vision_tokens) + + unpruned_tokens_sequence = torch.tensor(unpruned_tokens_sequence, dtype=torch.long) + video_token_mask = unpruned_tokens_sequence == VIDEO_TOKEN_ID + + pruning_mask = torch.bernoulli(video_token_mask.float() * video_pruning_rate).bool() # type: ignore[attr-defined] + # Sanity check that we don't prune what should not be pruned. 
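+ # torch.bernoulli(video_token_mask.float() * video_pruning_rate) draws one
+ # independent sample per position; non-video positions get probability 0.0,
+ # so only VIDEO_TOKEN_ID positions can ever be selected for pruning.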
+ assert not pruning_mask[~video_token_mask].any() + + retention_mask = ~pruning_mask + pruned_tokens_sequence = unpruned_tokens_sequence[retention_mask] + return unpruned_tokens_sequence, pruned_tokens_sequence, retention_mask + + +@pytest.mark.parametrize("spatial_merge_size", [1, 2]) +@pytest.mark.parametrize("grid_thw", [[3, 8, 7], [128, 10, 12]]) +@pytest.mark.parametrize("num_prefix_tokens", [1, 11]) +@pytest.mark.parametrize("num_suffix_tokens", [0, 7]) +@pytest.mark.parametrize("video_pruning_rate", [0, 0.25, 0.75]) +@pytest.mark.parametrize("interleave_text_tokens", [(0, 0), (1, 4)]) +def test_match_qwen3vl_mrope_evs_on( + spatial_merge_size: int, + num_prefix_tokens: int, + grid_thw: tuple[int, int, int], + num_suffix_tokens: int, + video_pruning_rate: float, + interleave_text_tokens: tuple[int, int], +): + hf_config = DummyConfig() + hf_config.vision_config.spatial_merge_size = spatial_merge_size + + t, h, w = grid_thw + population = list(range(1, 100)) + prefix_tokens = random.choices(population, k=num_prefix_tokens) + suffix_tokens = random.choices(population, k=num_suffix_tokens) + + video_tokens, video_tokens_pruned, retention_mask = make_video_embedding( + t, + h // spatial_merge_size, + w // spatial_merge_size, + interleave_text_tokens=interleave_text_tokens, + video_pruning_rate=video_pruning_rate, + ) + assert len(video_tokens) == len(retention_mask) + + input_tokens = prefix_tokens + video_tokens.tolist() + suffix_tokens + input_tokens_pruned = prefix_tokens + video_tokens_pruned.tolist() + suffix_tokens + + whole_sequence_retention_mask = torch.cat( + [ + torch.ones(len(prefix_tokens), dtype=torch.bool), + retention_mask, + torch.ones(len(suffix_tokens), dtype=torch.bool), + ], + dim=0, + ) + + # Build the GT mrope for unpruned input. + mm_feature = MultiModalFeatureSpec( + data=MultiModalKwargsItem( + { + "video_grid_thw": MultiModalFieldElem( + data=torch.tensor(grid_thw), + field=None, # HACK. + ), + } + ), + modality="video", + identifier="DUMMY", + mm_position=PlaceholderRange(offset=0, length=len(input_tokens)), + ) + expected_mrope, _ = Qwen3VLForConditionalGeneration._get_mrope_input_positions( + input_tokens=input_tokens, + mm_features=[mm_feature], + config=hf_config, + ) + + # Compute mrope for a video-only media (unpruned). + mm_feature = MultiModalFeatureSpec( + data=MultiModalKwargsItem( + { + "video_grid_thw": MultiModalFieldElem( + data=torch.tensor(grid_thw), + field=None, # HACK. + ), + } + ), + modality="video", + identifier="DUMMY", + mm_position=PlaceholderRange(offset=0, length=video_tokens.numel()), + ) + video_mrope, _ = Qwen3VLForConditionalGeneration._get_mrope_input_positions( + input_tokens=video_tokens.tolist(), + mm_features=[mm_feature], + config=hf_config, + ) + video_mrope = video_mrope.permute(1, 0) # [N, 3] + hidden_size = 16 + + is_video_embed = torch.isin( + video_tokens_pruned, torch.tensor([VIDEO_TOKEN_ID], dtype=torch.long) + ) + + expanded_positions = torch.full( + (len(video_tokens_pruned), 5), + fill_value=-100, + device=video_mrope.device, + dtype=torch.long, + ) + expanded_positions[is_video_embed, :3] = video_mrope[retention_mask][is_video_embed] + expanded_positions[~is_video_embed, :3] = video_mrope[retention_mask][ + ~is_video_embed + ] + + is_vision_start = video_tokens_pruned == VISION_START_TOKEN_ID + expanded_positions[..., 3] = is_vision_start + expanded_positions[..., 4] = is_video_embed + + # Check that all positions were filled, since we initialized them as negative. 
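+ # Channels 0-2 were overwritten with non-negative mrope positions for every
+ # row, and channels 3-4 hold 0/1 boolean flags, so no -100 fill value can
+ # remain anywhere in the tensor.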
+ assert (expanded_positions >= 0).all() + + video_embeddings = torch.empty( + (len(video_tokens_pruned), hidden_size), device=video_mrope.device + ) + + video_embeddings = torch.cat( + [ + video_embeddings, + expanded_positions.float(), + ], + dim=1, + ) + multimodal_embeddings = [video_embeddings] + + expected_mrope_masked = expected_mrope[:, whole_sequence_retention_mask] + + # Initialize computed_mrope with sequential positions for all prefix tokens + computed_mrope = torch.empty((3, len(input_tokens_pruned)), dtype=torch.long) + computed_mrope[:, 0 : len(prefix_tokens)] = expected_mrope[ + :, 0 : len(prefix_tokens) + ] + + # Paranoia check that computed_mrope is wrong. + assert not torch.equal(computed_mrope, expected_mrope_masked) + + _, actual_mrope, _ = Qwen3VLForConditionalGeneration._recompute_mrope_positions( + input_ids=input_tokens_pruned, + multimodal_embeddings=multimodal_embeddings, + mrope_positions=computed_mrope, + num_computed_tokens=len(prefix_tokens), + vision_start_token_id=hf_config.vision_start_token_id, + image_token_id=hf_config.image_token_id, + video_token_id=hf_config.video_token_id, + ) + + assert torch.equal(actual_mrope, expected_mrope_masked) diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index 3eeefbb3f..cd5c5356e 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -195,6 +195,8 @@ class Qwen2_5_VLVideoPixelInputs(TensorSchema): - second_per_grid_ts: The video time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs. Returned when `videos` is not `None`. + - timestamps: List of timestamp values (in seconds) for each frame + after merging. Length equals the temporal dimension after merging. """ type: Literal["pixel_values_videos"] @@ -214,6 +216,8 @@ class Qwen2_5_VLVideoPixelInputs(TensorSchema): TensorShape("nv"), ] + timestamps: list[list[float]] | None = None + class Qwen2_5_VLVideoEmbeddingInputs(TensorSchema): """ @@ -232,6 +236,8 @@ class Qwen2_5_VLVideoEmbeddingInputs(TensorSchema): - second_per_grid_ts: The video time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs. Returned when `videos` is not `None`. + - timestamps: List of timestamp values (in seconds) for each frame + after merging. Length equals the temporal dimension after merging. 
""" type: Literal["video_embeds"] @@ -250,6 +256,7 @@ class Qwen2_5_VLVideoEmbeddingInputs(TensorSchema): torch.Tensor | None, TensorShape("nv"), ] = None + timestamps: list[list[float]] | None = None Qwen2_5_VLVideoInputs: TypeAlias = ( diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index c4c71faf3..aeacd99eb 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -755,6 +755,7 @@ def _create_qwen2vl_field_factory( "video", video_embed_grid_sizes ), video_grid_thw=MultiModalFieldConfig.batched("video", keep_on_cpu=True), + timestamps=MultiModalFieldConfig.batched("video", keep_on_cpu=True), ) return _qwen2vl_field_config diff --git a/vllm/model_executor/models/qwen3_5.py b/vllm/model_executor/models/qwen3_5.py index 66d8ff8e1..30823ada1 100644 --- a/vllm/model_executor/models/qwen3_5.py +++ b/vllm/model_executor/models/qwen3_5.py @@ -628,6 +628,9 @@ class Qwen3_5MoeForCausalLM(Qwen3_5ForCausalLMBase, QwenNextMixtureOfExperts): dummy_inputs=Qwen3VLDummyInputsBuilder, ) class Qwen3_5ForConditionalGeneration(Qwen3VLForConditionalGeneration, IsHybrid): + # Qwen3.5 does not support multimodal pruning (EVS). + supports_multimodal_pruning = False + packed_modules_mapping = Qwen3VLForConditionalGeneration.packed_modules_mapping | { "in_proj_qkvz": ["in_proj_qkv", "in_proj_z"], "in_proj_ba": ["in_proj_b", "in_proj_a"], @@ -643,10 +646,8 @@ class Qwen3_5ForConditionalGeneration(Qwen3VLForConditionalGeneration, IsHybrid) self.config = config self.multimodal_config = multimodal_config self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data" - self.video_pruning_rate = multimodal_config.video_pruning_rate - self.is_multimodal_pruning_enabled = ( - multimodal_config.is_multimodal_pruning_enabled() - ) + # Qwen3.5 does not support multimodal pruning (EVS). + self.is_multimodal_pruning_enabled = False with self._mark_tower_model(vllm_config, {"image", "video"}): self.visual = Qwen3_VisionTransformer( @@ -693,6 +694,12 @@ class Qwen3_5ForConditionalGeneration(Qwen3VLForConditionalGeneration, IsHybrid) return inputs_embeds + def recompute_mrope_positions(self, *args, **kwargs): + raise NotImplementedError( + "Qwen3.5 does not support multimodal pruning (EVS). " + "recompute_mrope_positions should never be called." + ) + def forward( self, input_ids: torch.Tensor, @@ -851,10 +858,8 @@ class Qwen3_5MoeForConditionalGeneration( self.config = config self.multimodal_config = multimodal_config self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data" - self.video_pruning_rate = multimodal_config.video_pruning_rate - self.is_multimodal_pruning_enabled = ( - multimodal_config.is_multimodal_pruning_enabled() - ) + # Qwen3.5 does not support multimodal pruning (EVS). 
+ self.is_multimodal_pruning_enabled = False with self._mark_tower_model(vllm_config, {"image", "video"}): self.visual = Qwen3_VisionTransformer( diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py index e5bdbd802..b19811977 100644 --- a/vllm/model_executor/models/qwen3_vl.py +++ b/vllm/model_executor/models/qwen3_vl.py @@ -79,6 +79,7 @@ from vllm.multimodal.inputs import ( MultiModalDataDict, MultiModalFeatureSpec, MultiModalFieldConfig, + MultiModalFieldElem, MultiModalKwargsItem, MultiModalKwargsItems, PlaceholderRange, @@ -93,6 +94,8 @@ from vllm.multimodal.processing import ( PromptUpdateDetails, ) from vllm.sequence import IntermediateTensors +from vllm.tokenizers.protocol import TokenizerLike +from vllm.tokenizers.registry import cached_tokenizer_from_config from vllm.utils.collection_utils import is_list_of from vllm.utils.math_utils import round_up @@ -763,7 +766,6 @@ class Qwen3VLProcessingInfo(Qwen2VLProcessingInfo): def _get_video_second_idx( self, metadata: dict[str, Any], - out_item: MultiModalKwargsItem, do_sample_frames: bool | None = None, sampled_fps: float | None = None, ) -> list[int]: @@ -956,6 +958,7 @@ class Qwen3VLMultiModalProcessor(BaseMultiModalProcessor[Qwen3VLProcessingInfo]) if videos := mm_data.pop("videos", []): video_grid_thw_lst = [] pixel_values_videos_lst = [] + timestamps_per_video = [] for item in videos: video_array, metadata = item @@ -979,6 +982,14 @@ class Qwen3VLMultiModalProcessor(BaseMultiModalProcessor[Qwen3VLProcessingInfo]) **{k: metadata[k] for k in metadata if k != "do_sample_frames"} ) + # Compute timestamps here where we have access to metadata + timestamps = self.info._get_video_second_idx( + metadata=metadata, + do_sample_frames=video_mm_kwargs["do_sample_frames"], + sampled_fps=video_mm_kwargs.get("fps"), + ) + timestamps_per_video.append(timestamps) + video_mm_data = dict() video_mm_data["videos"] = [[video_array]] video_mm_data["video_metadata"] = [[metadata]] @@ -989,6 +1000,49 @@ class Qwen3VLMultiModalProcessor(BaseMultiModalProcessor[Qwen3VLProcessingInfo]) mm_kwargs=video_mm_kwargs, tok_kwargs=tok_kwargs, ) + + merge_size = processor.video_processor.merge_size + # Get video grid info for EVS calculation. + video_grid_thw = video_outputs["video_grid_thw"] + num_frames = int(video_grid_thw[0, 0]) + tokens_per_frame_base = int(video_grid_thw[0, 1:].prod()) // ( + merge_size**2 + ) + + # Apply EVS if enabled. + video_pruning_rate = self.info.ctx.get_mm_config().video_pruning_rate + if video_pruning_rate is not None and video_pruning_rate > 0.0: + num_tokens = compute_retained_tokens_count( + tokens_per_frame=tokens_per_frame_base, + num_frames=num_frames, + q=video_pruning_rate, + ) + # Here we just need placeholders that won't actually be replaced - + # we just need to make sure the total number of tokens is correct + # assign all tokens to the first frame. 
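+ # For example, with num_frames=4 and 56 tokens retained after EVS,
+ # this yields tokens_per_frame = [56, 0, 0, 0].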
+ tokens_per_frame = [num_tokens] + [0] * (num_frames - 1) + select_token_id = False + else: + tokens_per_frame = [tokens_per_frame_base] * num_frames + select_token_id = True + + # Generate the video replacement with EVS-adjusted token counts + tokenizer = self.info.get_tokenizer() + hf_config = self.info.get_hf_config() + video_repl = Qwen3VLMultiModalProcessor.get_video_repl( + tokens_per_frame=tokens_per_frame, + timestamps=timestamps, + tokenizer=tokenizer, + vision_start_token_id=hf_config.vision_start_token_id, + vision_end_token_id=hf_config.vision_end_token_id, + video_token_id=hf_config.video_token_id, + select_token_id=select_token_id, + ) + + # Convert token IDs to text for the HF processor flow + video_placeholder = tokenizer.decode( + video_repl.full, skip_special_tokens=False + ) input_ids = video_outputs.pop("input_ids") video_placeholder = processor.tokenizer.batch_decode(input_ids)[0] prompt = prompt.replace( @@ -1002,6 +1056,7 @@ class Qwen3VLMultiModalProcessor(BaseMultiModalProcessor[Qwen3VLProcessingInfo]) video_outputs = dict( pixel_values_videos=torch.cat(pixel_values_videos_lst), video_grid_thw=torch.cat(video_grid_thw_lst), + timestamps=timestamps_per_video, ) else: video_outputs = dict() @@ -1057,60 +1112,42 @@ class Qwen3VLMultiModalProcessor(BaseMultiModalProcessor[Qwen3VLProcessingInfo]) grid_thw = out_item["video_grid_thw"].data assert isinstance(grid_thw, torch.Tensor) - video, metadata = mm_items["video"][item_idx] - do_sample_frames = hf_processor_mm_kwargs.get("do_sample_frames") sampled_fps = hf_processor_mm_kwargs.get("fps") if is_list_of(sampled_fps, float): sampled_fps = sampled_fps[item_idx] - timestamps = self.info._get_video_second_idx( - metadata, out_item, do_sample_frames, sampled_fps - ) + timestamps = out_item["timestamps"].data assert len(timestamps) == grid_thw[0], ( f"The timestamps length({len(timestamps)}) should be equal " f"video length ({grid_thw[0]})." 
) - frames_idx_token = [ - tokenizer.encode(f"<{curr_time:.1f} seconds>", add_special_tokens=False) - for curr_time in timestamps - ] - tokens_per_frame = int(grid_thw[1:].prod()) // merge_length - per_frame_token_counts = [tokens_per_frame for _ in frames_idx_token] + # Compute tokens per frame, with EVS support + num_frames = int(grid_thw[0]) + tokens_per_frame_base = int(grid_thw[1:].prod()) // merge_length video_pruning_rate = self.info.ctx.get_mm_config().video_pruning_rate if video_pruning_rate is not None and video_pruning_rate > 0.0: - total_retained = compute_retained_tokens_count( - tokens_per_frame, - len(frames_idx_token), - video_pruning_rate, + num_tokens = compute_retained_tokens_count( + tokens_per_frame=tokens_per_frame_base, + num_frames=num_frames, + q=video_pruning_rate, ) - if len(frames_idx_token) == 0: - per_frame_token_counts = [] - elif len(frames_idx_token) == 1: - per_frame_token_counts = [tokens_per_frame] - else: - first_frame_tokens = tokens_per_frame - remaining_tokens = max(total_retained - first_frame_tokens, 0) - base = remaining_tokens // (len(frames_idx_token) - 1) - remainder = remaining_tokens % (len(frames_idx_token) - 1) - per_frame_token_counts = [first_frame_tokens] - for frame_idx in range(1, len(frames_idx_token)): - extra = base + (1 if (frame_idx - 1) < remainder else 0) - per_frame_token_counts.append(extra) + tokens_per_frame = [num_tokens] + [0] * (num_frames - 1) + select_token_id = False + else: + tokens_per_frame = [tokens_per_frame_base] * num_frames + select_token_id = True - placeholder = [] - for frame_idx, timestamp_tokens in enumerate(frames_idx_token): - placeholder.extend(timestamp_tokens) - tokens_this_frame = per_frame_token_counts[ - frame_idx if frame_idx < len(per_frame_token_counts) else -1 - ] - placeholder.extend( - [vision_start_token_id] - + [video_token_id] * tokens_this_frame - + [vision_end_token_id] - ) - return PromptUpdateDetails.select_token_id(placeholder, video_token_id) + return Qwen3VLMultiModalProcessor.get_video_repl( + tokens_per_frame=tokens_per_frame, + timestamps=timestamps, + tokenizer=tokenizer, + vision_start_token_id=vision_start_token_id, + vision_end_token_id=vision_end_token_id, + video_token_id=video_token_id, + select_token_id=select_token_id, + ) return [ PromptReplacement( @@ -1127,6 +1164,69 @@ class Qwen3VLMultiModalProcessor(BaseMultiModalProcessor[Qwen3VLProcessingInfo]) ), ] + @staticmethod + def get_video_repl( + *, + tokens_per_frame: list[int], + timestamps: list[float | int], + tokenizer: TokenizerLike, + vision_start_token_id: int, + vision_end_token_id: int, + video_token_id: int, + select_token_id: bool = False, + ) -> PromptUpdateDetails[list[int]]: + """Build prompt replacement for a video in Qwen3VL format. + + The replacement structure for each frame is: + timestamp_tokens + vision_start_token + video_tokens + vision_end_token + + Args: + tokens_per_frame: Number of video tokens per frame (can vary per frame for + EVS). + timestamps: List of timestamps in seconds for each frame + tokenizer: Tokenizer to encode timestamp strings + vision_start_token_id: Token ID for vision start marker + vision_end_token_id: Token ID for vision end marker + video_token_id: Token ID for video content + + Returns: + PromptUpdateDetails with full token sequence + """ + assert len(timestamps) == len(tokens_per_frame), ( + "timestamps and tokens_per_frame must have the same length" + ) + + # Tokenize timestamp strings independently to avoid tokenizer merging + # tokens across boundaries. 
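+ # For example, timestamps [0.0, 2.0] produce the strings "<0.0 seconds>"
+ # and "<2.0 seconds>", each encoded as its own token-id list rather than
+ # as part of one long prompt string.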
+ # TODO: switch to `_seq2tokens` which has some caching. + timestamp_token_ids = [ + tokenizer.encode(f"<{timestamp:.1f} seconds>", add_special_tokens=False) + for timestamp in timestamps + ] + + # Build the full token sequence + all_token_ids = [] + for frame_timestamp_ids, num_tokens in zip( + timestamp_token_ids, tokens_per_frame + ): + # Add timestamp tokens + all_token_ids.extend(frame_timestamp_ids) + + # Add vision tokens: vision_start + video_tokens + vision_end + all_token_ids.append(vision_start_token_id) + all_token_ids.extend([video_token_id] * num_tokens) + all_token_ids.append(vision_end_token_id) + + if select_token_id: + return PromptUpdateDetails.select_token_id(all_token_ids, video_token_id) + + # NOTE: we use `from_seq` instead of `select_token_id` because we want all + # tokens in the placeholder to be initially marked as candidates. Then + # in `get_input_embeddings``, we refine the mask to only replace + # `video_token_id` / `image_token_id`` positions with video/image embeddings, + # keeping text embeddings for timestamps and structural tokens. + return PromptUpdateDetails.from_seq(all_token_ids) + @support_torch_compile( dynamic_arg_dims={ @@ -1280,6 +1380,7 @@ class Qwen3VLForConditionalGeneration( multimodal_config = vllm_config.model_config.multimodal_config self.config = config + self._tokenizer = cached_tokenizer_from_config(vllm_config.model_config) self.multimodal_config = multimodal_config self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data" self.video_pruning_rate = multimodal_config.video_pruning_rate @@ -1419,6 +1520,7 @@ class Qwen3VLForConditionalGeneration( video_embeds = kwargs.pop("video_embeds", None) video_grid_thw = kwargs.pop("video_grid_thw", None) second_per_grid_ts = kwargs.pop("second_per_grid_ts", None) + timestamps = kwargs.pop("timestamps", None) if pixel_values_videos is None and video_embeds is None: return None @@ -1429,6 +1531,7 @@ class Qwen3VLForConditionalGeneration( pixel_values_videos=pixel_values_videos, video_grid_thw=video_grid_thw, second_per_grid_ts=second_per_grid_ts, + timestamps=timestamps, ) if video_embeds is not None: @@ -1436,6 +1539,7 @@ class Qwen3VLForConditionalGeneration( type="video_embeds", video_embeds=video_embeds, video_grid_thw=video_grid_thw, + timestamps=timestamps, ) def _process_image_input( @@ -1502,19 +1606,29 @@ class Qwen3VLForConditionalGeneration( Returns: Tuple of image embeddings for each image item. - Resulting embeddings will have extra 4 channels for - computed mrope positions. + Resulting embeddings will have extra 5 channels for + computed mrope positions, consistent with video embeddings. 
""" - merge_size = self.visual.spatial_merge_size - grid_thw = image_input["image_grid_thw"] - grid_thw_list = grid_thw.tolist() - image_embeds_out = [] - for emb, size in zip(image_embeds_split, grid_thw_list): - positions = compute_mrope_for_media(size, merge_size).to(emb.device) - emb = torch.cat([emb, positions], dim=1) - image_embeds_out.append(emb) - image_embeds_split = image_embeds_out - return tuple(image_embeds_split) + if self.is_multimodal_pruning_enabled: + merge_size = self.visual.spatial_merge_size + grid_thw = image_input["image_grid_thw"] + grid_thw_list = grid_thw.tolist() + image_embeds_out = [] + for emb, size in zip(image_embeds_split, grid_thw_list): + positions = compute_mrope_for_media(size, merge_size).to(emb.device) + positions = torch.cat( + [ + positions, + torch.zeros_like( + positions[:, 0:1] + ), # Dummy extra fifth channel + ], + dim=1, + ) + emb = torch.cat([emb, positions], dim=1) + image_embeds_out.append(emb) + image_embeds_split = tuple(image_embeds_out) + return image_embeds_split def _postprocess_video_embeds_evs( self, @@ -1531,63 +1645,219 @@ class Qwen3VLForConditionalGeneration( Returns: Tuple of video embeddings for each video item. - Resulting embeddings will have extra 4 channels for - computed mrope positions. + Resulting embeddings will have extra 5 channels for computed mrope + positions, and whether the index corresponds to a video embedding. """ grid_thw = video_input["video_grid_thw"] assert grid_thw.ndim == 2 grid_thw_list = grid_thw.tolist() merge_size = self.visual.spatial_merge_size - # Cast to long to match the original code - # https://github.com/huggingface/transformers/blob/41980ce93e775f6c88500c51c8db7946fc6a2add/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py#L491 # noqa - second_per_grid_ts = video_input.get("second_per_grid_ts") - if second_per_grid_ts is None: - # For Qwen3-VL, second_per_grid_ts might not be available - # Use default value of 1.0 for each video - second_per_grid_ts = torch.ones(len(grid_thw_list), dtype=torch.long) - else: - second_per_grid_ts = second_per_grid_ts.long() - tokens_per_second = getattr(self.config.vision_config, "tokens_per_second", 1.0) - + # Apply EVS to each video. video_embeds_out = [] - for emb, size, video_second_per_grid_t in zip( - video_embeds_split, grid_thw_list, second_per_grid_ts - ): - # For each video, we compute retention mask using EVS - retention_mask = compute_retention_mask( - emb, - size, - spatial_merge_size=self.visual.spatial_merge_size, - q=self.video_pruning_rate, + for video_idx, (emb, size) in enumerate(zip(video_embeds_split, grid_thw_list)): + # Compute positions. + timestamps = video_input.timestamps[video_idx] + num_frames = len(timestamps) + + t, h, w = size + if self.is_multimodal_pruning_enabled: + # For each video, compute retention mask using EVS. + # retention_mask: [11424]. + retention_mask = compute_retention_mask( + emb, + size, + spatial_merge_size=self.visual.spatial_merge_size, + q=self.video_pruning_rate, + ) + # Apply retention mask. + emb = emb[retention_mask] + + # Calculate the actual number of retained tokens per frame. 
+ num_frames, rows, cols = ( + t, + h // merge_size, + w // merge_size, + ) + retention_mask_thw = retention_mask.reshape(num_frames, rows, cols) + num_tokens_per_frame = ( + retention_mask_thw.sum(dim=(1, 2)).long().tolist() + ) + else: + feature_size = emb.shape[0] // num_frames + num_tokens_per_frame = [feature_size] * num_frames + retention_mask = None + + emb = self._create_final_video_embeddings( + video_embeddings=emb, + num_tokens_per_frame=num_tokens_per_frame, + timestamps=timestamps, + video_grid_thw=size, + retention_mask=retention_mask, ) - # Debug logging for EVS pruning - logger.debug( - "EVS: Video tokens pruned from %d to %d (T=%d,H=%d,W=%d, " - "pruning_rate=%.2f, reduction=%.1f%%)", - emb.shape[0], - retention_mask.sum().item(), - size[0], - size[1], - size[2], - self.video_pruning_rate, - (1 - retention_mask.float().mean().item()) * 100, - ) - - positions = compute_mrope_for_media( - size, - merge_size, - tokens_per_second=tokens_per_second, - video_second_per_grid=video_second_per_grid_t.item(), - ).to(emb.device) - - emb = emb[retention_mask] - positions = positions[retention_mask] - emb = torch.cat([emb, positions], dim=1) video_embeds_out.append(emb) + return tuple(video_embeds_out) + def _create_final_video_embeddings( + self, + video_embeddings: torch.Tensor, + num_tokens_per_frame: list[int], + timestamps: list[float], + video_grid_thw: list[int], + retention_mask: torch.Tensor, + ) -> torch.Tensor: + """Create final embeddings that combine video embeddings with + text embeddings of indicator tokens. + + These final embeddings contain: + - Actual video embeddings in positions corresponding to video content + - Text embeddings for indicator tokens (, , and + frame separation text) in their respective positions + + These embeddings will replace the placeholder embeddings to create + input_embeds for the LLM. + """ + device = video_embeddings.device + + # Generate video replacement token IDs using get_video_repl + # This tokenizes each frame separator independently, then uses pre-tokenized + # special tokens to ensure consistent tokenization regardless of + # num_tokens_per_frame values. + video_repl = Qwen3VLMultiModalProcessor.get_video_repl( + tokens_per_frame=num_tokens_per_frame, + tokenizer=self._tokenizer, + timestamps=timestamps, + vision_start_token_id=self.config.vision_start_token_id, + vision_end_token_id=self.config.vision_end_token_id, + video_token_id=self.config.video_token_id, + select_token_id=self.is_multimodal_pruning_enabled, + ) + + repl_token_ids = torch.tensor(video_repl.full, device=device) + embed_token_id = _cached_tensor(self.config.video_token_id, device=device) + is_video_embed = torch.isin(repl_token_ids, embed_token_id) + + # Get text embeddings for indicator tokens (has only `visual_dim``). 
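+ # These cover every token in the replacement (timestamps, vision start/end
+ # and video placeholders); the placeholder rows are overwritten with the
+ # visual embeddings by `_merge_multimodal_embeddings` below, while timestamp
+ # and structural tokens keep their text embeddings.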
+ text_embeddings = self.get_language_model().embed_input_ids(repl_token_ids) + + if self.use_deepstack: + ( + deepstack_input_embeds, + multimodal_embeddings, + ) = self._compute_deepstack_embeds( + inputs_embeds=text_embeddings, + multimodal_embeddings=[video_embeddings], + is_multimodal=is_video_embed, + ) + else: + deepstack_input_embeds = None + multimodal_embeddings = [video_embeddings] + + merged_embeddings = _merge_multimodal_embeddings( + inputs_embeds=text_embeddings, + multimodal_embeddings=multimodal_embeddings, + is_multimodal=is_video_embed, + ) + + to_concat = [merged_embeddings] + if deepstack_input_embeds is not None: + to_concat.append( + deepstack_input_embeds.permute(1, 0, 2).reshape( + deepstack_input_embeds.shape[1], -1 + ) + ) + + expanded_positions = None + if self.is_multimodal_pruning_enabled: + is_vision_start = repl_token_ids.eq(self.config.vision_start_token_id) + expanded_positions = self._get_expanded_positions( + device=merged_embeddings.device, + seq_len=merged_embeddings.shape[0], + video_grid_thw=video_grid_thw, + num_tokens_per_frame=num_tokens_per_frame, + timestamps=timestamps, + is_video_embed=is_video_embed, + is_vision_start=is_vision_start, + retention_mask=retention_mask, + ) + to_concat.append(expanded_positions) + + final_video_embeddings = torch.cat(to_concat, dim=-1) + + return final_video_embeddings + + def _get_expanded_positions( + self, + device, + seq_len, + video_grid_thw, + num_tokens_per_frame, + timestamps, + is_video_embed, + is_vision_start, + retention_mask, + ): + embed_token_id = _cached_tensor(self.config.video_token_id, device=device) + + # Expand positions to match the full sequence length + # (includes both video tokens and indicator tokens) + # Shape: [full_length, 5] where positions are filled for video tokens + # and zeros for indicator tokens. + # Channel 3 flags VISION_START tokens so that + # recompute_mrope_positions can reliably count timestamp tokens + # (even when early frames have all video tokens pruned). + # Channel 4 flags video-embedding tokens. + expanded_positions = torch.zeros( + seq_len, + 5, # [t_index, h_index, w_index, is_vision_start, is_video] + device=device, + dtype=torch.long, + ) + _, h, w = video_grid_thw + merge_size = self.visual.spatial_merge_size + num_frames = len(num_tokens_per_frame) + unpruned_token_ids = Qwen3VLMultiModalProcessor.get_video_repl( + tokens_per_frame=[(h // merge_size) * (w // merge_size)] * num_frames, + tokenizer=self._tokenizer, + timestamps=timestamps, + vision_start_token_id=self.config.vision_start_token_id, + vision_end_token_id=self.config.vision_end_token_id, + video_token_id=self.config.video_token_id, + ).full + unpruned_token_ids_tensor = torch.tensor(unpruned_token_ids, device=device) + mm_feature = MultiModalFeatureSpec( + data=MultiModalKwargsItem( + { + "video_grid_thw": MultiModalFieldElem( + data=torch.tensor(video_grid_thw), + field=None, # HACK. 
+ ), + } + ), + modality="video", + identifier="DUMMY", + mm_position=PlaceholderRange(offset=0, length=len(unpruned_token_ids)), + ) + original_mrope = ( + self.get_mrope_input_positions( + input_tokens=unpruned_token_ids, + mm_features=[mm_feature], + )[0] + .to(device) + .permute(1, 0) + ) + full_is_video_embed = unpruned_token_ids_tensor == embed_token_id + expanded_positions[is_video_embed, :3] = original_mrope[full_is_video_embed][ + retention_mask + ] + expanded_positions[~is_video_embed, :3] = original_mrope[~full_is_video_embed] + expanded_positions[..., 3] = is_vision_start + expanded_positions[..., 4] = is_video_embed + + return expanded_positions + def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: mm_input_by_modality = {} for input_key in kwargs: @@ -1607,66 +1877,77 @@ class Qwen3VLForConditionalGeneration( ) return mm_input_by_modality - def iter_mm_grid_hw( - self, input_tokens: list[int], mm_features: list[MultiModalFeatureSpec] - ) -> Iterator[tuple[int, int, int]]: - """ - Iterate over multimodal features and yield grid information. - - For videos with EVS (Efficient Video Sampling) enabled, this function - computes the offset based on the pruned token count rather than relying - on input_tokens.index(), which would fail when tokens are pruned. + @staticmethod + def _iter_mm_grid_hw( + input_tokens: list[int], + mm_features: list[MultiModalFeatureSpec], + video_token_id: int, + vision_start_token_id: int, + vision_end_token_id: int, + spatial_merge_size: int, + ) -> Iterator[tuple[int, int, int, int]]: + """Iterate over multimodal features and yield position info. Args: - input_tokens: List of token IDs in the prompt - mm_features: List of multimodal feature specifications + input_tokens: List of token IDs in the input sequence. + mm_features: List of multimodal feature specifications containing + image/video data and position information. + video_token_id: Token ID used for video tokens. + vision_start_token_id: Token ID marking the start of a vision sequence. + vision_end_token_id: Token ID marking the end of a vision sequence. + spatial_merge_size: Size of the spatial merge operation used to + compute logical grid dimensions from the original feature grid. Yields: - Tuple of (offset, grid_h, grid_w) for each frame/image + offset: Position of the first video/image token in the sequence. + llm_grid_h: Logical grid height (may not match actual token count with EVS). + llm_grid_w: Logical grid width (may not match actual token count with EVS). + actual_num_tokens: Actual number of video/image tokens in the placeholder. 
""" - video_token_id = self.config.video_token_id - spatial_merge_size = self.config.vision_config.spatial_merge_size for mm_feature in sorted(mm_features, key=lambda f: f.mm_position.offset): offset = mm_feature.mm_position.offset if mm_feature.modality == "image": t, h, w = mm_feature.data["image_grid_thw"].data.tolist() assert t == 1, f"Image must have 1 frame, got {t}" - yield offset, h // spatial_merge_size, w // spatial_merge_size + llm_grid_h = h // spatial_merge_size + llm_grid_w = w // spatial_merge_size + yield offset, llm_grid_h, llm_grid_w, llm_grid_h * llm_grid_w elif mm_feature.modality == "video": t, h, w = mm_feature.data["video_grid_thw"].data.tolist() llm_grid_h = h // spatial_merge_size llm_grid_w = w // spatial_merge_size - # Check if EVS (Efficient Video Sampling) is enabled - is_evs_enabled = ( - hasattr(self, "video_pruning_rate") - and self.video_pruning_rate is not None - and self.video_pruning_rate > 0.0 - ) + for _ in range(t): + # When EVS is enabled, some frames may have 0 video tokens in the + # placeholder. We use `vision_start_token_id` to locate each frame + # since it is always present for every frame. + # We then look for the first `video_token_id` after + # `vision_start_token_id` and before `vision_end_token_id`. + offset = input_tokens.index(vision_start_token_id, offset) + vision_end_offset = input_tokens.index(vision_end_token_id, offset) - if is_evs_enabled: - frame_offsets = self._extract_frame_offsets_from_mask( - mm_feature.mm_position, t - ) - if frame_offsets is not None: - for rel_offset in frame_offsets: - yield offset + rel_offset, llm_grid_h, llm_grid_w - continue + try: + actual_num_tokens = 0 + video_offset = input_tokens.index( + video_token_id, offset, vision_end_offset + ) + # NOTE: looking at the + # `Qwen3VLMultiModalProcessor.get_video_repl` code, we can + # see that we can use the below formula to get the token + # count, since everything in between `video_offset` and + # `vision_end_offset` is populated as `video_token_id`. + # This saves us from manually counting the number tokens + # that match `video_token_id` in between. + actual_num_tokens += vision_end_offset - video_offset + except ValueError: + # No `video_token_id` in this frame (EVS with 0 tokens for + # this frame) -> use `offset + 1`` to move past + # `vision_start_token_id`. + video_offset = offset + 1 - # If EVS is enabled but mask is missing, this indicates a bug - # in the prompt processing pipeline. The is_embed mask should - # always be present when video_pruning_rate > 0. - raise RuntimeError( - f"EVS is enabled (pruning_rate={self.video_pruning_rate}) " - "but is_embed mask is missing from mm_position. " - "This indicates a bug in prompt processing." - ) - else: - # Non-EVS mode: Use original logic with input_tokens.index() - for _ in range(t): - offset = input_tokens.index(video_token_id, offset) - yield offset, llm_grid_h, llm_grid_w - offset += llm_grid_h * llm_grid_w + yield video_offset, llm_grid_h, llm_grid_w, actual_num_tokens + # Move offset past this frame for next iteration. 
+ offset = vision_end_offset + 1 else: raise ValueError(f"Unsupported modality: {mm_feature.modality}") @@ -1771,13 +2052,100 @@ class Qwen3VLForConditionalGeneration( return [len(seg) for seg in segments] + def get_mrope_input_positions( + self, + input_tokens: list[int], + mm_features: list[MultiModalFeatureSpec], + ) -> tuple[torch.Tensor, int]: + return self._get_mrope_input_positions( + input_tokens=input_tokens, + mm_features=mm_features, + config=self.config, + ) + + @staticmethod + def _get_mrope_input_positions( + input_tokens: list[int], + mm_features: list[MultiModalFeatureSpec], + config: Qwen3VLConfig, + ): + llm_pos_ids_list = [] + st = 0 + for ( + offset, + llm_grid_h, + llm_grid_w, + actual_num_tokens, + ) in Qwen3VLForConditionalGeneration._iter_mm_grid_hw( + input_tokens, + mm_features, + video_token_id=config.video_token_id, + vision_start_token_id=config.vision_start_token_id, + vision_end_token_id=config.vision_end_token_id, + spatial_merge_size=config.vision_config.spatial_merge_size, + ): + # Skip frames with 0 tokens (EVS placeholder with tokens lumped elsewhere) + if actual_num_tokens == 0: + continue + + text_len = offset - st + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append( + np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx + ) + + # Check if this is a "lumped placeholder" (all tokens from multiple frames + # assigned to the 0-th frame - see + # `Qwen3VLMultiModalProcessor.get_video_repl`. + expected_tokens_per_frame = llm_grid_h * llm_grid_w + if actual_num_tokens > expected_tokens_per_frame: + # Lumped placeholder: create grid positions for all "logical" frames + # represented. + num_logical_frames = actual_num_tokens // expected_tokens_per_frame + remainder = actual_num_tokens % expected_tokens_per_frame + + # Create positions for complete frames. + for _ in range(num_logical_frames): + grid_indices = np.indices((1, llm_grid_h, llm_grid_w)).reshape( + 3, -1 + ) + llm_pos_ids_list.append(grid_indices + text_len + st_idx) + st_idx = llm_pos_ids_list[-1].max() + 1 + text_len = 0 # No text between frames within the lump + + # Handle remainder tokens if any (partial frame). + # NOTE: this should never be the case. Should we have an assert? + if remainder > 0: + # Create a partial grid - take first 'remainder' positions + full_grid = np.indices((1, llm_grid_h, llm_grid_w)).reshape(3, -1) + grid_indices = full_grid[:, :remainder] + llm_pos_ids_list.append(grid_indices + text_len + st_idx) + else: + # Normal case: frame has exactly the expected tokens (after actual EVS + # pruning). 
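+ # e.g. np.indices((1, 2, 2)).reshape(3, -1) gives
+ # [[0, 0, 0, 0], [0, 0, 1, 1], [0, 1, 0, 1]], i.e. one (t, h, w)
+ # coordinate column per video token of this frame.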
+ grid_indices = np.indices((1, llm_grid_h, llm_grid_w)).reshape(3, -1) + llm_pos_ids_list.append(grid_indices + text_len + st_idx) + + st = offset + actual_num_tokens + + if st < len(input_tokens): + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + text_len = len(input_tokens) - st + llm_pos_ids_list.append( + np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx + ) + + llm_positions = np.concatenate(llm_pos_ids_list, axis=1).reshape(3, -1) + mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item() + return torch.from_numpy(llm_positions), mrope_position_delta + def recompute_mrope_positions( self, input_ids: list[int], - multimodal_embeddings: tuple[torch.Tensor, ...], + multimodal_embeddings: MultiModalEmbeddings, mrope_positions: torch.LongTensor, num_computed_tokens: int, - ) -> tuple[tuple[torch.Tensor, ...], torch.Tensor, int]: + ) -> tuple[MultiModalEmbeddings, torch.Tensor, int]: """ Update part of input mrope positions (starting with num_computed_tokens index). Original mrope_positions are computed @@ -1786,9 +2154,10 @@ class Qwen3VLForConditionalGeneration( mrope_positions before we feed it to LLM. Args: - input_ids: (N,) All input tokens of the prompt (Containing - entire sequence). - multimodal_embeddings: Tuple of multimodal embeddings. + input_ids: (N,) All input tokens of the prompt containing + entire sequence. + multimodal_embeddings: Tuple of multimodal embeddings that + fits into the prefill chunk that is being processed. mrope_positions: Existing mrope positions (3, N) for entire sequence num_computed_tokens: A number of computed tokens so far. @@ -1797,10 +2166,26 @@ class Qwen3VLForConditionalGeneration( Tuple of (multimodal_embeddings, mrope_positions, mrope_position_delta). 
""" - image_token_id = self.config.image_token_id - video_token_id = self.config.video_token_id - vision_start_token_id = self.config.vision_start_token_id + return self._recompute_mrope_positions( + input_ids=input_ids, + multimodal_embeddings=multimodal_embeddings, + mrope_positions=mrope_positions, + num_computed_tokens=num_computed_tokens, + image_token_id=self.config.image_token_id, + video_token_id=self.config.video_token_id, + vision_start_token_id=self.config.vision_start_token_id, + ) + @staticmethod + def _recompute_mrope_positions( + input_ids: list[int], + multimodal_embeddings: MultiModalEmbeddings, + mrope_positions: torch.LongTensor, + num_computed_tokens: int, + vision_start_token_id: int, + image_token_id: int, + video_token_id: int, + ) -> tuple[MultiModalEmbeddings, torch.Tensor, int]: # Device device = ( multimodal_embeddings[0].device @@ -1811,10 +2196,21 @@ class Qwen3VLForConditionalGeneration( # Tensors input_ids_t = torch.as_tensor(input_ids, device=device, dtype=torch.long) - mm_embeddings_out = [mm[:, :-4] for mm in multimodal_embeddings] - mm_embeddings_pos = [ - mm[:, -4:].permute(1, 0).long() for mm in multimodal_embeddings - ] + mm_embeddings_out = [] + mm_embeddings_pos = [] + # Strip position information from embeddings (last 5 channels) + # For Qwen3 VL, handle potentially empty frames (from unpacking) + for mm in multimodal_embeddings: + if mm.shape[0] > 0: # Only process non-empty frames + mm_embeddings_out.append(mm[:, :-5]) + mm_embeddings_pos.append(mm[:, -5:].permute(1, 0).long()) + else: + # Empty frame - keep as is + mm_embeddings_out.append(mm) + # Create empty position tensor with correct shape + mm_embeddings_pos.append( + torch.empty(5, 0, device=device, dtype=torch.long) + ) positions, mrope_positions_delta = recompute_mrope_positions( input_ids_t, @@ -1828,107 +2224,14 @@ class Qwen3VLForConditionalGeneration( return tuple(mm_embeddings_out), positions, mrope_positions_delta - def get_mrope_input_positions( - self, - input_tokens: list[int], - mm_features: list[MultiModalFeatureSpec], - ) -> tuple[torch.Tensor, int]: - # Pre-collect actual frame token counts for EVS mode - frame_token_counts_map = {} - for mm_feature in mm_features: - if mm_feature.modality == "video": - is_evs_enabled = ( - hasattr(self, "video_pruning_rate") - and self.video_pruning_rate is not None - and self.video_pruning_rate > 0.0 - ) - if is_evs_enabled: - t = mm_feature.data["video_grid_thw"].data.tolist()[0] - token_counts = self._get_actual_frame_token_counts( - mm_feature.mm_position, t - ) - assert token_counts is not None, ( - "EVS enabled but failed to extract frame token counts " - "from is_embed mask" - ) - frame_token_counts_map[mm_feature.mm_position.offset] = token_counts - - llm_pos_ids_list = [] - st = 0 - frame_counts_idx = {} - - for offset, llm_grid_h, llm_grid_w in self.iter_mm_grid_hw( - input_tokens, mm_features - ): - text_len = offset - st - st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 - - # Determine actual token count for this frame - base_offset = None - for feat_offset in frame_token_counts_map: - if offset >= feat_offset: - base_offset = feat_offset - - if base_offset is not None: - # EVS mode: use actual token count from is_embed mask - assert base_offset in frame_token_counts_map, ( - f"Found base_offset {base_offset} but not in frame_token_counts_map" - ) - - if base_offset not in frame_counts_idx: - frame_counts_idx[base_offset] = 0 - - counts = frame_token_counts_map[base_offset] - idx = 
frame_counts_idx[base_offset] - - assert idx < len(counts), ( - f"EVS frame index {idx} out of range (total frames: {len(counts)})" - ) - - actual_frame_tokens = counts[idx] - frame_counts_idx[base_offset] += 1 - else: - # Non-EVS mode (or image): use theoretical grid size - actual_frame_tokens = llm_grid_h * llm_grid_w - - # Add text segment - text_positions = ( - np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx - ) - llm_pos_ids_list.append(text_positions) - st_idx += text_len - - # Add frame segment with actual token count (not theoretical) - grid_indices = np.indices((1, llm_grid_h, llm_grid_w)).reshape(3, -1) - # Only take the first actual_frame_tokens positions - frame_positions = grid_indices[:, :actual_frame_tokens] + st_idx - llm_pos_ids_list.append(frame_positions) - - # Update st using actual token count - st = offset + actual_frame_tokens - - # Handle final text segment - if st < len(input_tokens): - st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 - text_len = len(input_tokens) - st - final_text_positions = ( - np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx - ) - llm_pos_ids_list.append(final_text_positions) - - llm_positions = np.concatenate(llm_pos_ids_list, axis=1).reshape(3, -1) - mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item() - - return torch.from_numpy(llm_positions), mrope_position_delta - def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings | None: mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs) if not mm_input_by_modality: return None # The result multimodal_embeddings is tuple of tensors, with each - # tensor correspoending to a multimodal data item (image or video). - multimodal_embeddings: tuple[torch.Tensor, ...] = () + # tensor corresponding to a multimodal data item (image or video). + multimodal_embeddings: list[torch.Tensor] = [] # NOTE: It is important to iterate over the keys in this dictionary # to preserve the order of the modalities. 
@@ -1936,19 +2239,20 @@ class Qwen3VLForConditionalGeneration( multimodal_input = mm_input_by_modality[modality] if modality == "image": image_embeddings = self._process_image_input(multimodal_input) - if self.is_multimodal_pruning_enabled: - image_embeddings = self._postprocess_image_embeds_evs( - image_embeddings, multimodal_input - ) - multimodal_embeddings += tuple(image_embeddings) + image_embeddings = self._postprocess_image_embeds_evs( + image_embeddings, multimodal_input + ) + multimodal_embeddings.extend(image_embeddings) if modality == "video": video_embeddings = self._process_video_input(multimodal_input) if self.is_multimodal_pruning_enabled: video_embeddings = self._postprocess_video_embeds_evs( video_embeddings, multimodal_input ) - multimodal_embeddings += tuple(video_embeddings) - return multimodal_embeddings + multimodal_embeddings.extend(video_embeddings) + + embeddings_tuple = tuple(multimodal_embeddings) + return embeddings_tuple def _compute_deepstack_embeds( self, @@ -2128,3 +2432,8 @@ class Qwen3VLForConditionalGeneration( vision_config = hf_config.vision_config merge_size = vision_config.spatial_merge_size return num_vision_tokens // merge_size**2 + + +@lru_cache +def _cached_tensor(x, device) -> torch.Tensor: + return torch.tensor(x, device=device) diff --git a/vllm/model_executor/models/qwen3_vl_moe.py b/vllm/model_executor/models/qwen3_vl_moe.py index 80815616b..e6fc7d409 100644 --- a/vllm/model_executor/models/qwen3_vl_moe.py +++ b/vllm/model_executor/models/qwen3_vl_moe.py @@ -45,6 +45,7 @@ from vllm.model_executor.model_loader.weight_utils import ( ) from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.sequence import IntermediateTensors +from vllm.tokenizers.registry import cached_tokenizer_from_config from .interfaces import MixtureOfExperts from .qwen3_moe import ( @@ -415,6 +416,7 @@ class Qwen3VLMoeForConditionalGeneration( multimodal_config = vllm_config.model_config.multimodal_config self.config = config + self._tokenizer = cached_tokenizer_from_config(vllm_config.model_config) self.multimodal_config = multimodal_config self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data" self.video_pruning_rate = multimodal_config.video_pruning_rate diff --git a/vllm/multimodal/evs.py b/vllm/multimodal/evs.py index 8a36ea415..62611c897 100644 --- a/vllm/multimodal/evs.py +++ b/vllm/multimodal/evs.py @@ -170,9 +170,9 @@ def recompute_mrope_positions( multimodal_embeddings may contain zero, some or even some part of all multimodal_embeddings for a given prompt. - Each multimodal_positions has 4 extra channels - (First 3 channels corresponds to original 3 mrope positions, last channel - is the maximum width of the media repeated). Provided multimodal_positions + Each multimodal_positions has 4 or 5 extra channels + (first 3 channels correspond to the original 3 mrope positions; + remaining channels vary by model — see below). Provided multimodal_positions do not reflect location of media position in sequence - they are computed like the media is in the 0-th position in the sequence. @@ -186,6 +186,16 @@ def recompute_mrope_positions( Args: input_ids: (N,) All input tokens of the prompt (entire sequence). multimodal_positions: List of mrope positions for each media. + If a given element is of shape (4, N), it is assumed to only describe + positions for video / image embeddings. This is the case of e.g. Qwen2.5 VL, + where each multimodal input is a contiguous chunk of embeddings. + The expected channels are [t, h, w, max_width]. 
+ If it is of shape (5, N), it is assumed to possibly describe positions for + both video / image embeddings, as well as text embeddings. This is the case + of e.g. Qwen3 VL, where each video inputs are comprised of individual + frames' embeddings, interleaved with embeddings for timestamp tokens, + and vision start / end tokens. The expected channels are + [t, h, w, is_vision_start, is_vision]. mrope_positions: Existing mrope positions (4, N) for entire sequence. num_computed_tokens: A number of computed tokens so far. vision_start_token_id: Token indicating start of vision media. @@ -233,6 +243,21 @@ def recompute_mrope_positions( # - Current prefill chunk has no vision start indexes at all # - Vision start token appeared in previous prefill round # - Regular case + has_video_tokens = False + num_timestamp_tokens = 0 + if mm_pos.shape[0] == 5 and mm_pos.shape[1] > 0: + # mm_pos[4, :] indicates which positions are for video embeddings. + # If there are no video embeddings, skip timestamp adjustment. + has_video_tokens = torch.any(mm_pos[4, :]).item() + if has_video_tokens: + # Channel 3 flags VISION_START tokens. Timestamp tokens + # precede the first VISION_START, so its index gives us the + # exact timestamp count. This is robust even when early + # frames have all their video tokens pruned (which would + # push argmax(channel 4) far into a later frame). + first_vs = (mm_pos[3, :] == 1).nonzero(as_tuple=True)[0] + num_timestamp_tokens = first_vs[0].item() if len(first_vs) > 0 else 0 + seen_vision_start_indices = vision_start_indices[ vision_start_indices < num_computed_tokens ] @@ -249,6 +274,18 @@ def recompute_mrope_positions( in_the_middle_of_media = ( seen_mm_tokens > seem_mm_tokens_before_last_vision_start ) + # For Qwen3 VL, we can be inside a media segment even before any + # video tokens appear (timestamp tokens are text). If we've passed + # the last vision_start token but haven't reached the first video + # embedding, treat this as "in the middle of media". + if ( + not in_the_middle_of_media + and has_video_tokens + and num_computed_tokens > last_vision_start_token + and num_computed_tokens + <= last_vision_start_token + num_timestamp_tokens + 1 + ): + in_the_middle_of_media = True if in_the_middle_of_media: mm_embeddings_seen = ( @@ -274,14 +311,39 @@ def recompute_mrope_positions( mm_embeddings_seen = 0 global_mm_start = next_vision_start_token - # Offset right after vision_start_token - base = positions[-1, global_mm_start] + 1 - local_start = global_mm_start + 1 + mm_embeddings_seen + # For Qwen3 VL, mm_pos includes timestamp tokens before vision_start + # when starting a new media. Adjust global_mm_start to point to where + # the sequence actually begins (before timestamp tokens). + adjusted_for_timestamps = False + if mm_pos.shape[0] == 5 and mm_embeddings_seen == 0 and has_video_tokens: + # NOTE: -1 is because there is a vision start token right after + # timestamp tokens before any video embeddings appear. + + # Adjust global_mm_start to point to the first timestamp token + # instead of the vision_start token. 
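+ # e.g. if this video's first frame is prefixed by 4 timestamp tokens,
+ # the media replacement really starts 4 positions before the
+ # vision_start token found above.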
+ global_mm_start -= num_timestamp_tokens + adjusted_for_timestamps = True + + # Offset calculation depends on whether we adjusted for timestamp tokens + if adjusted_for_timestamps: + # Start from position before the first timestamp token + base = positions[-1, global_mm_start - 1] + 1 + local_start = global_mm_start + mm_embeddings_seen + else: + # Original logic: start after vision_start_token + base = positions[-1, global_mm_start] + 1 + local_start = global_mm_start + 1 + mm_embeddings_seen + local_end = local_start + mm_pos.shape[1] positions[:, local_start:local_end] = mm_pos[0:3] + base - # mm_pos[3, 0] is the max width of the media - offset = mm_pos[3, 0] + base + # For Qwen3 VL (5-channel), use the maximum position reached across + # all tokens (both video and text) in all dimensions (t, h, w). + # For Qwen2.5 VL (4-channel), mm_pos[3, 0] is the max width. + if mm_pos.shape[0] == 5: + offset = mm_pos[0:3, :].max() + base + 1 + else: + offset = mm_pos[3, 0] + base text_pos_sum = torch.cumsum(text_mask[local_end:].long(), dim=0)