[Model] Use mm_position to compute mrope positions for GLM-4.xV (#33039)
Signed-off-by: Yang <lymailforjob@gmail.com>
@@ -5,11 +5,11 @@
 # https://github.com/zai-org/CogAgent
 """Inference-only CogAgent model compatible with THUDM weights."""
 
-import itertools
 from argparse import Namespace
-from collections.abc import Mapping, Sequence
+from collections.abc import Iterator, Mapping, Sequence
 from typing import Annotated, Literal
 
+import numpy as np
 import torch
 from torch import nn
 from torch.nn import LayerNorm
@@ -624,138 +624,56 @@ class GLM4VForCausalLM(
 
         return self.transformer.vision(pixel_values)
 
+    def iter_mm_grid_thw(
+        self, mm_features: list[MultiModalFeatureSpec]
+    ) -> Iterator[tuple[int, int, int, int]]:
+        hf_config = self.config
+        spatial_merge_size = hf_config.vision_config.spatial_merge_size
+        for mm_feature in sorted(mm_features, key=lambda f: f.mm_position.offset):
+            offset = mm_feature.mm_position.offset
+            if mm_feature.modality == "image":
+                t, h, w = mm_feature.data["image_grid_thw"].data.tolist()
+                assert t == 1, f"Image must have 1 frame, got {t}"
+                yield offset, t, h // spatial_merge_size, w // spatial_merge_size
+            else:
+                # glm4v only supports image modality
+                raise ValueError(f"Unsupported modality: {mm_feature.modality}")
+
     def get_mrope_input_positions(
         self,
         input_tokens: list[int],
         mm_features: list[MultiModalFeatureSpec],
     ) -> tuple[torch.Tensor, int]:
-        kwargs = MultiModalFeatureSpec.gather_kwargs(
-            mm_features,
-            {"image_grid_thw", "video_grid_thw"},
-        )
-        image_grid_thw = [item.tolist() for item in kwargs.get("image_grid_thw", [])]
-        video_grid_thw = [item.tolist() for item in kwargs.get("video_grid_thw", [])]
-
-        hf_config = self.config
-        image_token_id = hf_config.image_token_id
-        video_start_token_id = hf_config.video_start_token_id
-        video_end_token_id = hf_config.video_end_token_id
-        spatial_merge_size = hf_config.vision_config.spatial_merge_size
         llm_pos_ids_list: list = []
+        st = 0
+        for (
+            offset,
+            llm_grid_t,
+            llm_grid_h,
+            llm_grid_w,
+        ) in self.iter_mm_grid_thw(mm_features):
+            text_len = offset - st
+            st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
+            llm_pos_ids_list.append(
+                np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx
+            )
+            grid_indices = np.indices((llm_grid_t, llm_grid_h, llm_grid_w)).reshape(
+                3, -1
+            )
+            llm_pos_ids_list.append(grid_indices + text_len + st_idx)
+            # EVA2CLIPModel has embeddings for boi and eoi tokens as well
+            st = offset + 1 + llm_grid_t * llm_grid_h * llm_grid_w + 1
 
-        if image_grid_thw or video_grid_thw:
-            input_token_type: list[str] = []
-            video_check_flg = False
-            for token in input_tokens:
-                if token == video_start_token_id:
-                    video_check_flg = True
-                elif token == video_end_token_id:
-                    video_check_flg = False
+        if st < len(input_tokens):
+            text_len = len(input_tokens) - st
+            st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
+            llm_pos_ids_list.append(
+                np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx
+            )
 
-                if (token == image_token_id) and (video_check_flg is False):
-                    input_token_type.append("image")
-                elif (token == image_token_id) and (video_check_flg is True):
-                    input_token_type.append("video")
-                else:
-                    input_token_type.append("text")
-
-            input_type_group: list[tuple[str, int, int]] = []
-            for key, group_iter in itertools.groupby(
-                enumerate(input_token_type), lambda x: x[1]
-            ):
-                group_list = list(group_iter)
-                start_index = group_list[0][0]
-                end_index = group_list[-1][0] + 1
-                input_type_group.append((key, start_index, end_index))
-
-            video_frame_num = 1
-            mm_data_idx = 0
-            for modality_type, start_idx, end_idx in input_type_group:
-                st_idx = (
-                    llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
-                )
-                if modality_type == "image":
-                    t, h, w = image_grid_thw[mm_data_idx]
-                    llm_grid_t, llm_grid_h, llm_grid_w = (
-                        t,
-                        h // spatial_merge_size,
-                        w // spatial_merge_size,
-                    )
-
-                    t_index = (
-                        torch.arange(llm_grid_t)
-                        .view(-1, 1)
-                        .expand(-1, llm_grid_h * llm_grid_w)
-                        .flatten()
-                    )
-                    h_index = (
-                        torch.arange(llm_grid_h)
-                        .view(1, -1, 1)
-                        .expand(llm_grid_t, -1, llm_grid_w)
-                        .flatten()
-                    )
-                    w_index = (
-                        torch.arange(llm_grid_w)
-                        .view(1, 1, -1)
-                        .expand(llm_grid_t, llm_grid_h, -1)
-                        .flatten()
-                    )
-                    llm_pos_ids_list.append(
-                        torch.stack([t_index, h_index, w_index]) + st_idx
-                    )
-                    mm_data_idx += 1
-
-                elif modality_type == "video":
-                    t, h, w = (
-                        video_frame_num,
-                        *image_grid_thw[mm_data_idx][1:],
-                    )
-                    llm_grid_t, llm_grid_h, llm_grid_w = (
-                        t,
-                        h // spatial_merge_size,
-                        w // spatial_merge_size,
-                    )
-
-                    for t_idx in range(llm_grid_t):
-                        t_index = (
-                            torch.tensor(t_idx)
-                            .view(-1, 1)
-                            .expand(-1, llm_grid_h * llm_grid_w)
-                            .flatten()
-                        )
-                        h_index = (
-                            torch.arange(llm_grid_h)
-                            .view(1, -1, 1)
-                            .expand(1, -1, llm_grid_w)
-                            .flatten()
-                        )
-                        w_index = (
-                            torch.arange(llm_grid_w)
-                            .view(1, 1, -1)
-                            .expand(1, llm_grid_h, -1)
-                            .flatten()
-                        )
-                        llm_pos_ids_list.append(
-                            torch.stack([t_index, h_index, w_index]) + st_idx
-                        )
-
-                    mm_data_idx += 1
-                    video_frame_num += 1
-
-                else:
-                    text_len = end_idx - start_idx
-                    llm_pos_ids_list.append(
-                        torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
-                    )
-                    video_frame_num = 1
-
-        else:
-            text_len = len(input_tokens)
-            llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1))
-
-        llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
+        llm_positions = np.concatenate(llm_pos_ids_list, axis=1).reshape(3, -1)
         mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
-        return llm_positions, mrope_position_delta
+        return torch.from_numpy(llm_positions), mrope_position_delta
 
     embed_input_ids = SupportsMultiModal.embed_input_ids
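
A sanity check on the new grid construction (not part of the commit; the grid
dims below are made up): np.indices builds, in one call, the same (3, t*h*w)
coordinate grid that the removed code assembled from three per-axis
arange/view/expand chains.

import numpy as np
import torch

t, h, w = 2, 3, 4  # hypothetical merged grid dims

# New path: np.indices yields shape (3, t, h, w); row 0 holds the t
# coordinate of every patch, row 1 the h coordinate, row 2 the w coordinate.
grid_np = np.indices((t, h, w)).reshape(3, -1)

# Old path: one arange/view/expand chain per axis, then a stack.
t_index = torch.arange(t).view(-1, 1).expand(-1, h * w).flatten()
h_index = torch.arange(h).view(1, -1, 1).expand(t, -1, w).flatten()
w_index = torch.arange(w).view(1, 1, -1).expand(t, h, -1).flatten()
grid_torch = torch.stack([t_index, h_index, w_index])

assert (torch.from_numpy(grid_np) == grid_torch).all()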
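
And a toy trace of the new position arithmetic (again illustrative only; the
prompt layout and offset are made up): text tokens advance all three mrope
rows together, image patches get per-axis grid positions, and
mrope_position_delta records how far the final position falls short of the
sequence length so that decode can resume from max + 1.

import numpy as np

# Hypothetical prompt: 4 text tokens, boi, a (1, 2, 2) merged image grid,
# eoi, then 2 trailing text tokens -> 12 tokens total.
st = 0
offset = 4  # mm_position.offset of the image feature
pos = []

text_len = offset - st
pos.append(np.broadcast_to(np.arange(text_len), (3, text_len)))  # 0..3
grid = np.indices((1, 2, 2)).reshape(3, -1)
pos.append(grid + text_len)  # patch positions start right after the text
st = offset + 1 + 1 * 2 * 2 + 1  # skip boi + 4 patches + eoi, as in the commit

input_len = st + 2  # 2 trailing text tokens
st_idx = pos[-1].max() + 1
tail = input_len - st
pos.append(np.broadcast_to(np.arange(tail), (3, tail)) + st_idx)

llm_positions = np.concatenate(pos, axis=1)
mrope_position_delta = (llm_positions.max() + 1 - input_len).item()
print(mrope_position_delta)  # -4: decode continues from position 8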