[Model] Pass mm_features directly into get_mrope_input_positions (#28399)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2025-11-11 21:14:48 +08:00
committed by GitHub
parent 7dbe6d81d6
commit afffd3cc8a
15 changed files with 225 additions and 272 deletions

View File

@@ -65,7 +65,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.qwen2_audio import Qwen2AudioProcessingInfo
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalKwargsItems
from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalKwargsItems
from vllm.multimodal.parse import AudioProcessorItems, MultiModalDataItems
from vllm.multimodal.processing import (
BaseMultiModalProcessor,
@@ -1414,39 +1414,48 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
def get_mrope_input_positions(
self,
input_tokens: list[int],
hf_config: PretrainedConfig,
image_grid_thw: list[list[int]] | torch.Tensor | None,
video_grid_thw: list[list[int]] | torch.Tensor | None,
second_per_grid_ts: list[float] | None = None,
audio_feature_lengths: torch.Tensor | None = None,
use_audio_in_video: bool = False,
mm_features: list[MultiModalFeatureSpec],
) -> tuple[torch.Tensor, int]:
config = hf_config.thinker_config
if isinstance(image_grid_thw, list):
image_grid_thw = torch.tensor(image_grid_thw)
if isinstance(video_grid_thw, list):
video_grid_thw = torch.tensor(video_grid_thw)
kwargs = MultiModalFeatureSpec.gather_kwargs(
mm_features,
{
"image_grid_thw",
"video_grid_thw",
"second_per_grid_ts",
"audio_feature_lengths",
"use_audio_in_video",
},
)
image_grid_thw = kwargs.get("image_grid_thw", [])
video_grid_thw = kwargs.get("video_grid_thw", [])
second_per_grid_ts = kwargs.get("second_per_grid_ts", [])
audio_feature_lengths = kwargs.get("audio_feature_lengths", [])
use_audio_in_video = any(kwargs.get("use_audio_in_video", []))
image_grid_thw = (torch.stack if image_grid_thw else torch.tensor)(
image_grid_thw
)
video_grid_thw = (torch.stack if video_grid_thw else torch.tensor)(
video_grid_thw
)
input_ids = torch.tensor(input_tokens)
if input_ids is None or input_ids.ndim != 1:
raise ValueError("_omni3_get_input_positions_tensor expects 1D input_ids")
seq_len = input_ids.shape[0]
if audio_feature_lengths is not None and not isinstance(
audio_feature_lengths, torch.Tensor
):
audio_feature_lengths = torch.as_tensor(
if isinstance(audio_feature_lengths, list):
audio_feature_lengths = torch.tensor(
audio_feature_lengths, dtype=torch.long
)
if second_per_grid_ts is None:
if video_grid_thw is not None and video_grid_thw.numel() > 0:
second_per_grids = torch.ones(
video_grid_thw.shape[0], dtype=torch.float32
)
else:
second_per_grids = torch.tensor([], dtype=torch.float32)
if not len(second_per_grid_ts) and len(video_grid_thw):
second_per_grids = torch.ones(len(video_grid_thw), dtype=torch.float32)
else:
second_per_grids = torch.tensor(second_per_grid_ts, dtype=torch.float32)
config = self.config
spatial_merge_size = config.vision_config.spatial_merge_size
image_token_id = config.image_token_id
video_token_id = config.video_token_id