[Refactor]: Use M-RoPE interface directly while defining model class instead of maintaining model specific M-RoPE implementation in mrope.py (#24172)
Signed-off-by: Divyansh Singhvi <divyanshsinghvi@gmail.com> Signed-off-by: dsinghvi <divyanshsinghvi@gmail.com> Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk> Co-authored-by: wwl2755 <wangwenlong2755@gmail.com>
This commit is contained in:
@@ -5,6 +5,7 @@
|
||||
# https://github.com/zai-org/CogAgent
|
||||
"""Inference-only CogAgent model compatible with THUDM weights."""
|
||||
|
||||
import itertools
|
||||
from argparse import Namespace
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Annotated, Literal, Optional, Union
|
||||
@@ -14,7 +15,7 @@ from torch import nn
|
||||
from torch.nn import LayerNorm
|
||||
from torchvision import transforms
|
||||
from torchvision.transforms import InterpolationMode
|
||||
from transformers import BatchFeature, PreTrainedTokenizer, TensorType
|
||||
from transformers import BatchFeature, PretrainedConfig, PreTrainedTokenizer, TensorType
|
||||
from transformers.image_utils import ImageInput
|
||||
from transformers.tokenization_utils_base import TextInput
|
||||
|
||||
@@ -54,6 +55,7 @@ from .chatglm import ChatGLMBaseModel, ChatGLMModel
|
||||
from .interfaces import (
|
||||
MultiModalEmbeddings,
|
||||
SupportsLoRA,
|
||||
SupportsMRoPE,
|
||||
SupportsMultiModal,
|
||||
SupportsPP,
|
||||
)
|
||||
@@ -554,7 +556,9 @@ class GLM4VMultiModalProcessor(BaseMultiModalProcessor[GLM4VProcessingInfo]):
|
||||
info=GLM4VProcessingInfo,
|
||||
dummy_inputs=GLM4VDummyInputsBuilder,
|
||||
)
|
||||
class GLM4VForCausalLM(ChatGLMBaseModel, SupportsMultiModal, SupportsLoRA, SupportsPP):
|
||||
class GLM4VForCausalLM(
|
||||
ChatGLMBaseModel, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE
|
||||
):
|
||||
merge_by_field_config = True
|
||||
|
||||
packed_modules_mapping = {
|
||||
@@ -615,6 +619,150 @@ class GLM4VForCausalLM(ChatGLMBaseModel, SupportsMultiModal, SupportsLoRA, Suppo
|
||||
|
||||
return self.transformer.vision(pixel_values)
|
||||
|
||||
@classmethod
|
||||
def get_mrope_input_positions(
|
||||
cls,
|
||||
input_tokens: list[int],
|
||||
hf_config: PretrainedConfig,
|
||||
image_grid_thw: Union[list[list[int]], torch.Tensor],
|
||||
video_grid_thw: Union[list[list[int]], torch.Tensor],
|
||||
context_len: int = 0,
|
||||
seq_len: Optional[int] = None,
|
||||
second_per_grid_ts: Optional[list[float]] = None,
|
||||
audio_feature_lengths: Optional[torch.Tensor] = None,
|
||||
use_audio_in_video: bool = False,
|
||||
) -> tuple[torch.Tensor, int]:
|
||||
"""Get mrope input positions and delta value for GLM4V."""
|
||||
|
||||
image_token_id = hf_config.image_token_id
|
||||
video_start_token_id = hf_config.video_start_token_id
|
||||
video_end_token_id = hf_config.video_end_token_id
|
||||
spatial_merge_size = hf_config.vision_config.spatial_merge_size
|
||||
llm_pos_ids_list: list = []
|
||||
|
||||
if not (image_grid_thw is None and video_grid_thw is None):
|
||||
if isinstance(image_grid_thw, torch.Tensor):
|
||||
image_grid_thw = image_grid_thw.tolist()
|
||||
|
||||
input_token_type: list[str] = []
|
||||
video_check_flg = False
|
||||
for token in input_tokens:
|
||||
if token == video_start_token_id:
|
||||
video_check_flg = True
|
||||
elif token == video_end_token_id:
|
||||
video_check_flg = False
|
||||
|
||||
if (token == image_token_id) and (video_check_flg is False):
|
||||
input_token_type.append("image")
|
||||
elif (token == image_token_id) and (video_check_flg is True):
|
||||
input_token_type.append("video")
|
||||
else:
|
||||
input_token_type.append("text")
|
||||
|
||||
input_type_group: list[tuple[str, int, int]] = []
|
||||
for key, group_iter in itertools.groupby(
|
||||
enumerate(input_token_type), lambda x: x[1]
|
||||
):
|
||||
group_list = list(group_iter)
|
||||
start_index = group_list[0][0]
|
||||
end_index = group_list[-1][0] + 1
|
||||
input_type_group.append((key, start_index, end_index))
|
||||
|
||||
video_frame_num = 1
|
||||
mm_data_idx = 0
|
||||
for modality_type, start_idx, end_idx in input_type_group:
|
||||
st_idx = (
|
||||
llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
|
||||
)
|
||||
if modality_type == "image":
|
||||
t, h, w = (
|
||||
image_grid_thw[mm_data_idx][0],
|
||||
image_grid_thw[mm_data_idx][1],
|
||||
image_grid_thw[mm_data_idx][2],
|
||||
)
|
||||
llm_grid_t, llm_grid_h, llm_grid_w = (
|
||||
t,
|
||||
h // spatial_merge_size,
|
||||
w // spatial_merge_size,
|
||||
)
|
||||
|
||||
t_index = (
|
||||
torch.arange(llm_grid_t)
|
||||
.view(-1, 1)
|
||||
.expand(-1, llm_grid_h * llm_grid_w)
|
||||
.flatten()
|
||||
)
|
||||
h_index = (
|
||||
torch.arange(llm_grid_h)
|
||||
.view(1, -1, 1)
|
||||
.expand(llm_grid_t, -1, llm_grid_w)
|
||||
.flatten()
|
||||
)
|
||||
w_index = (
|
||||
torch.arange(llm_grid_w)
|
||||
.view(1, 1, -1)
|
||||
.expand(llm_grid_t, llm_grid_h, -1)
|
||||
.flatten()
|
||||
)
|
||||
llm_pos_ids_list.append(
|
||||
torch.stack([t_index, h_index, w_index]) + st_idx
|
||||
)
|
||||
mm_data_idx += 1
|
||||
|
||||
elif modality_type == "video":
|
||||
t, h, w = (
|
||||
video_frame_num,
|
||||
image_grid_thw[mm_data_idx][1],
|
||||
image_grid_thw[mm_data_idx][2],
|
||||
)
|
||||
llm_grid_t, llm_grid_h, llm_grid_w = (
|
||||
t,
|
||||
h // spatial_merge_size,
|
||||
w // spatial_merge_size,
|
||||
)
|
||||
|
||||
for t_idx in range(llm_grid_t):
|
||||
t_index = (
|
||||
torch.tensor(t_idx)
|
||||
.view(-1, 1)
|
||||
.expand(-1, llm_grid_h * llm_grid_w)
|
||||
.flatten()
|
||||
)
|
||||
h_index = (
|
||||
torch.arange(llm_grid_h)
|
||||
.view(1, -1, 1)
|
||||
.expand(1, -1, llm_grid_w)
|
||||
.flatten()
|
||||
)
|
||||
w_index = (
|
||||
torch.arange(llm_grid_w)
|
||||
.view(1, 1, -1)
|
||||
.expand(1, llm_grid_h, -1)
|
||||
.flatten()
|
||||
)
|
||||
llm_pos_ids_list.append(
|
||||
torch.stack([t_index, h_index, w_index]) + st_idx
|
||||
)
|
||||
|
||||
mm_data_idx += 1
|
||||
video_frame_num += 1
|
||||
|
||||
else:
|
||||
text_len = end_idx - start_idx
|
||||
llm_pos_ids_list.append(
|
||||
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
|
||||
)
|
||||
video_frame_num = 1
|
||||
|
||||
else:
|
||||
text_len = len(input_tokens)
|
||||
llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1))
|
||||
|
||||
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
|
||||
llm_positions = llm_positions[:, context_len:seq_len]
|
||||
mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
|
||||
return llm_positions, mrope_position_delta
|
||||
|
||||
def get_language_model(self) -> torch.nn.Module:
|
||||
return self.transformer
|
||||
|
||||
|
||||
Reference in New Issue
Block a user