[Model] Pass mm_features directly into get_mrope_input_positions (#28399)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2025-11-11 21:14:48 +08:00
committed by GitHub
parent 7dbe6d81d6
commit afffd3cc8a
15 changed files with 225 additions and 272 deletions

View File

@@ -15,7 +15,7 @@ from torch import nn
from torch.nn import LayerNorm
from torchvision import transforms
from torchvision.transforms import InterpolationMode
from transformers import BatchFeature, PretrainedConfig, PreTrainedTokenizer, TensorType
from transformers import BatchFeature, PreTrainedTokenizer, TensorType
from transformers.image_utils import ImageInput
from transformers.tokenization_utils_base import TextInput
@@ -36,6 +36,7 @@ from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (
MultiModalDataDict,
MultiModalFeatureSpec,
MultiModalFieldConfig,
MultiModalKwargsItems,
)
@@ -622,25 +623,23 @@ class GLM4VForCausalLM(
def get_mrope_input_positions(
self,
input_tokens: list[int],
hf_config: PretrainedConfig,
image_grid_thw: list[list[int]] | torch.Tensor,
video_grid_thw: list[list[int]] | torch.Tensor,
second_per_grid_ts: list[float] | None = None,
audio_feature_lengths: torch.Tensor | None = None,
use_audio_in_video: bool = False,
mm_features: list[MultiModalFeatureSpec],
) -> tuple[torch.Tensor, int]:
"""Get mrope input positions and delta value for GLM4V."""
kwargs = MultiModalFeatureSpec.gather_kwargs(
mm_features,
{"image_grid_thw", "video_grid_thw"},
)
image_grid_thw = [item.tolist() for item in kwargs.get("image_grid_thw", [])]
video_grid_thw = [item.tolist() for item in kwargs.get("video_grid_thw", [])]
hf_config = self.config
image_token_id = hf_config.image_token_id
video_start_token_id = hf_config.video_start_token_id
video_end_token_id = hf_config.video_end_token_id
spatial_merge_size = hf_config.vision_config.spatial_merge_size
llm_pos_ids_list: list = []
if not (image_grid_thw is None and video_grid_thw is None):
if isinstance(image_grid_thw, torch.Tensor):
image_grid_thw = image_grid_thw.tolist()
if image_grid_thw or video_grid_thw:
input_token_type: list[str] = []
video_check_flg = False
for token in input_tokens:
@@ -672,11 +671,7 @@ class GLM4VForCausalLM(
llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
)
if modality_type == "image":
t, h, w = (
image_grid_thw[mm_data_idx][0],
image_grid_thw[mm_data_idx][1],
image_grid_thw[mm_data_idx][2],
)
t, h, w = image_grid_thw[mm_data_idx]
llm_grid_t, llm_grid_h, llm_grid_w = (
t,
h // spatial_merge_size,
@@ -709,8 +704,7 @@ class GLM4VForCausalLM(
elif modality_type == "video":
t, h, w = (
video_frame_num,
image_grid_thw[mm_data_idx][1],
image_grid_thw[mm_data_idx][2],
*image_grid_thw[mm_data_idx][1:],
)
llm_grid_t, llm_grid_h, llm_grid_w = (
t,