[Model] Pass mm_features directly into get_mrope_input_positions (#28399)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -15,7 +15,7 @@ from torch import nn
|
||||
from torch.nn import LayerNorm
|
||||
from torchvision import transforms
|
||||
from torchvision.transforms import InterpolationMode
|
||||
from transformers import BatchFeature, PretrainedConfig, PreTrainedTokenizer, TensorType
|
||||
from transformers import BatchFeature, PreTrainedTokenizer, TensorType
|
||||
from transformers.image_utils import ImageInput
|
||||
from transformers.tokenization_utils_base import TextInput
|
||||
|
||||
@@ -36,6 +36,7 @@ from vllm.model_executor.models.module_mapping import MultiModelKeys
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.inputs import (
|
||||
MultiModalDataDict,
|
||||
MultiModalFeatureSpec,
|
||||
MultiModalFieldConfig,
|
||||
MultiModalKwargsItems,
|
||||
)
|
||||
@@ -622,25 +623,23 @@ class GLM4VForCausalLM(
|
||||
def get_mrope_input_positions(
|
||||
self,
|
||||
input_tokens: list[int],
|
||||
hf_config: PretrainedConfig,
|
||||
image_grid_thw: list[list[int]] | torch.Tensor,
|
||||
video_grid_thw: list[list[int]] | torch.Tensor,
|
||||
second_per_grid_ts: list[float] | None = None,
|
||||
audio_feature_lengths: torch.Tensor | None = None,
|
||||
use_audio_in_video: bool = False,
|
||||
mm_features: list[MultiModalFeatureSpec],
|
||||
) -> tuple[torch.Tensor, int]:
|
||||
"""Get mrope input positions and delta value for GLM4V."""
|
||||
kwargs = MultiModalFeatureSpec.gather_kwargs(
|
||||
mm_features,
|
||||
{"image_grid_thw", "video_grid_thw"},
|
||||
)
|
||||
image_grid_thw = [item.tolist() for item in kwargs.get("image_grid_thw", [])]
|
||||
video_grid_thw = [item.tolist() for item in kwargs.get("video_grid_thw", [])]
|
||||
|
||||
hf_config = self.config
|
||||
image_token_id = hf_config.image_token_id
|
||||
video_start_token_id = hf_config.video_start_token_id
|
||||
video_end_token_id = hf_config.video_end_token_id
|
||||
spatial_merge_size = hf_config.vision_config.spatial_merge_size
|
||||
llm_pos_ids_list: list = []
|
||||
|
||||
if not (image_grid_thw is None and video_grid_thw is None):
|
||||
if isinstance(image_grid_thw, torch.Tensor):
|
||||
image_grid_thw = image_grid_thw.tolist()
|
||||
|
||||
if image_grid_thw or video_grid_thw:
|
||||
input_token_type: list[str] = []
|
||||
video_check_flg = False
|
||||
for token in input_tokens:
|
||||
@@ -672,11 +671,7 @@ class GLM4VForCausalLM(
|
||||
llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
|
||||
)
|
||||
if modality_type == "image":
|
||||
t, h, w = (
|
||||
image_grid_thw[mm_data_idx][0],
|
||||
image_grid_thw[mm_data_idx][1],
|
||||
image_grid_thw[mm_data_idx][2],
|
||||
)
|
||||
t, h, w = image_grid_thw[mm_data_idx]
|
||||
llm_grid_t, llm_grid_h, llm_grid_w = (
|
||||
t,
|
||||
h // spatial_merge_size,
|
||||
@@ -709,8 +704,7 @@ class GLM4VForCausalLM(
|
||||
elif modality_type == "video":
|
||||
t, h, w = (
|
||||
video_frame_num,
|
||||
image_grid_thw[mm_data_idx][1],
|
||||
image_grid_thw[mm_data_idx][2],
|
||||
*image_grid_thw[mm_data_idx][1:],
|
||||
)
|
||||
llm_grid_t, llm_grid_h, llm_grid_w = (
|
||||
t,
|
||||
|
||||
Reference in New Issue
Block a user