[Model] Pass mm_features directly into get_mrope_input_positions (#28399)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Author: Cyrus Leung
Date: 2025-11-11 21:14:48 +08:00
Committed by: GitHub
Parent: 7dbe6d81d6
Commit: afffd3cc8a
15 changed files with 225 additions and 272 deletions
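
For context, here is the interface change in miniature; the call site below is an illustrative assumption (only the parameter lists come from this diff):

# Before: callers unpacked per-modality metadata and passed it piecewise
# (hypothetical call site, for illustration only).
positions, mrope_delta = model.get_mrope_input_positions(
    input_tokens,
    hf_config=model.config,
    image_grid_thw=image_grid_thw,
    video_grid_thw=video_grid_thw,
    second_per_grid_ts=second_per_grid_ts,
)

# After: callers hand over the feature specs and the model gathers
# whatever per-modality metadata it needs.
positions, mrope_delta = model.get_mrope_input_positions(
    input_tokens,
    mm_features=mm_features,  # list[MultiModalFeatureSpec]
)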


@@ -35,7 +35,7 @@ import einops
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from transformers import BatchFeature, PretrainedConfig
+from transformers import BatchFeature
 from transformers.models.qwen2_5_vl import Qwen2_5_VLProcessor
 from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
     Qwen2_5_VLConfig,
@@ -75,7 +75,11 @@ from vllm.multimodal.evs import (
     compute_retention_mask,
     recompute_mrope_positions,
 )
-from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs
+from vllm.multimodal.inputs import (
+    MultiModalFeatureSpec,
+    MultiModalFieldConfig,
+    MultiModalKwargs,
+)
 from vllm.multimodal.parse import MultiModalDataItems
 from vllm.multimodal.processing import PromptReplacement, PromptUpdate
 from vllm.sequence import IntermediateTensors
@@ -1120,15 +1124,17 @@ class Qwen2_5_VLForConditionalGeneration(
     def get_mrope_input_positions(
         self,
         input_tokens: list[int],
-        hf_config: PretrainedConfig,
-        image_grid_thw: list[list[int]] | torch.Tensor,
-        video_grid_thw: list[list[int]] | torch.Tensor,
-        second_per_grid_ts: list[float],
-        audio_feature_lengths: torch.Tensor | None = None,
-        use_audio_in_video: bool = False,
+        mm_features: list[MultiModalFeatureSpec],
     ) -> tuple[torch.Tensor, int]:
         """Get mrope input positions and delta value."""
+        kwargs = MultiModalFeatureSpec.gather_kwargs(
+            mm_features,
+            {"image_grid_thw", "video_grid_thw", "second_per_grid_ts"},
+        )
+        image_grid_thw = [item.tolist() for item in kwargs.get("image_grid_thw", [])]
+        video_grid_thw = [item.tolist() for item in kwargs.get("video_grid_thw", [])]
+        second_per_grid_ts = kwargs.get("second_per_grid_ts", [])
+        hf_config = self.config
         image_token_id = hf_config.image_token_id
         video_token_id = hf_config.video_token_id
         vision_start_token_id = hf_config.vision_start_token_id
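
Judging purely from the call above, MultiModalFeatureSpec.gather_kwargs returns a dict mapping each requested key to the values collected across the feature specs, in order. A rough sketch of that inferred contract (the body and the `data` attribute are assumptions; only the call shape comes from the diff):

from collections import defaultdict

def gather_kwargs_sketch(mm_features, keys: set[str]) -> dict[str, list]:
    # Collect the requested keys from each feature spec into per-key lists,
    # preserving the order of mm_features. Hypothetical reconstruction,
    # not the actual vLLM implementation.
    out: dict[str, list] = defaultdict(list)
    for feature in mm_features:
        for key in keys:
            value = feature.data.get(key)  # "data" is an assumed attribute name
            if value is not None:
                out[key].append(value)
    return dict(out)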
@@ -1165,20 +1171,12 @@ class Qwen2_5_VLForConditionalGeneration(
             else:
                 ed_video = len(input_tokens) + 1
             if ed_image < ed_video:
-                t, h, w = (
-                    image_grid_thw[image_index][0],
-                    image_grid_thw[image_index][1],
-                    image_grid_thw[image_index][2],
-                )
+                t, h, w = image_grid_thw[image_index]
                 image_index += 1
                 remain_images -= 1
                 ed = ed_image
             else:
-                t, h, w = (
-                    video_grid_thw[video_index][0],
-                    video_grid_thw[video_index][1],
-                    video_grid_thw[video_index][2],
-                )
+                t, h, w = video_grid_thw[video_index]
                 video_second_per_grid_t = 1.0
                 if second_per_grid_ts:
                     video_second_per_grid_t = second_per_grid_ts[video_index]
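
This hunk works because each grid entry is now a plain three-element list after the .tolist() normalization earlier in the method, so a single tuple unpack replaces three indexed reads. A tiny illustration:

image_grid_thw = [[1, 36, 52], [1, 24, 24]]  # two images, [t, h, w] each
t, h, w = image_grid_thw[0]
assert (t, h, w) == (1, 36, 52)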