[Model] Pass mm_features directly into get_mrope_input_positions (#28399)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2025-11-11 21:14:48 +08:00
committed by GitHub
parent 7dbe6d81d6
commit afffd3cc8a
15 changed files with 225 additions and 272 deletions

View File

@@ -27,6 +27,7 @@ from vllm.model_executor.models.utils import WeightsMapper
from vllm.multimodal import MultiModalKwargsItems
from vllm.multimodal.inputs import (
MultiModalDataDict,
MultiModalFeatureSpec,
MultiModalFieldConfig,
MultiModalInputs,
MultiModalUUIDDict,
@@ -38,7 +39,7 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
if TYPE_CHECKING:
from transformers import BatchFeature, PretrainedConfig
from transformers import BatchFeature
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
@@ -367,20 +368,34 @@ class MultiModalMixin(SupportsMultiModal, SupportsMRoPE):
def get_mrope_input_positions(
self,
input_tokens: list[int],
hf_config: "PretrainedConfig",
image_grid_thw: list[list[int]] | torch.Tensor | None,
video_grid_thw: list[list[int]] | torch.Tensor | None,
second_per_grid_ts: list[float] | None = None,
audio_feature_lengths: torch.Tensor | None = None,
use_audio_in_video: bool = False,
mm_features: list[MultiModalFeatureSpec],
) -> tuple[torch.Tensor, int]:
if any((second_per_grid_ts, audio_feature_lengths, use_audio_in_video)):
kwargs = MultiModalFeatureSpec.gather_kwargs(
mm_features,
{
"image_grid_thw",
"video_grid_thw",
"second_per_grid_ts",
"audio_feature_lengths",
"use_audio_in_video",
},
)
if any(
v
for k, v in kwargs.items()
if k not in {"image_grid_thw", "video_grid_thw"}
):
raise NotImplementedError("Transformers backend only supports images.")
if isinstance(image_grid_thw, list):
image_grid_thw = torch.tensor(image_grid_thw)
if isinstance(video_grid_thw, list):
video_grid_thw = torch.tensor(video_grid_thw)
image_grid_thw = kwargs.get("image_grid_thw", [])
video_grid_thw = kwargs.get("video_grid_thw", [])
image_grid_thw = (torch.stack if image_grid_thw else torch.tensor)(
image_grid_thw
)
video_grid_thw = (torch.stack if video_grid_thw else torch.tensor)(
video_grid_thw
)
mrope_positions, mrope_position_delta = self.model.get_rope_index(
input_ids=torch.tensor(input_tokens).unsqueeze(0),