[Model] Use mm_position to compute mrope positions for Qwen2-VL/2.5-VL (#32126)
Signed-off-by: YunzhuLu <lucia.yunzhu@gmail.com>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
@@ -26,7 +26,7 @@
 """Inference-only Qwen2-VL model compatible with HuggingFace weights."""

 import math
-from collections.abc import Callable, Iterable, Mapping, Sequence
+from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence
 from functools import partial
 from typing import Annotated, Any, Literal, TypeAlias

@@ -1137,121 +1137,82 @@ class Qwen2VLForConditionalGeneration(

     supports_encoder_tp_data = True

+    def iter_mm_grid_thw(
+        self, mm_features: list[MultiModalFeatureSpec]
+    ) -> Iterator[tuple[int, int, int, int, float]]:
+        """
+        Iterate over multimodal features and yield grid information.
+
+        Args:
+            mm_features: List of multimodal feature specifications.
+
+        Yields:
+            Tuple of (offset, grid_t, grid_h, grid_w, t_factor) for each
+            image or video feature, ordered by offset in the prompt.
+        """
+        spatial_merge_size = self.config.vision_config.spatial_merge_size
+        tokens_per_second = getattr(self.config.vision_config, "tokens_per_second", 1.0)
+        for mm_feature in sorted(mm_features, key=lambda f: f.mm_position.offset):
+            offset = mm_feature.mm_position.offset
+            if mm_feature.modality == "image":
+                t, h, w = mm_feature.data["image_grid_thw"].data.tolist()
+                assert t == 1, f"Image must have 1 frame, got {t}"
+                yield offset, 1, h // spatial_merge_size, w // spatial_merge_size, 1.0
+            elif mm_feature.modality == "video":
+                t, h, w = mm_feature.data["video_grid_thw"].data.tolist()
+                second_per_grid_ts = 1.0
+                if mm_feature.data.get("second_per_grid_ts", None):
+                    second_per_grid_ts = mm_feature.data[
+                        "second_per_grid_ts"
+                    ].data.item()
+                t_factor = second_per_grid_ts * tokens_per_second
+                yield (
+                    offset,
+                    t,
+                    h // spatial_merge_size,
+                    w // spatial_merge_size,
+                    t_factor,
+                )
+            else:
+                raise ValueError(f"Unsupported modality: {mm_feature.modality}")
+
     def get_mrope_input_positions(
         self,
         input_tokens: list[int],
         mm_features: list[MultiModalFeatureSpec],
     ) -> tuple[torch.Tensor, int]:
-        kwargs = MultiModalFeatureSpec.gather_kwargs(
-            mm_features,
-            {"image_grid_thw", "video_grid_thw", "second_per_grid_ts"},
-        )
-        image_grid_thw = [item.tolist() for item in kwargs.get("image_grid_thw", [])]
-        video_grid_thw = [item.tolist() for item in kwargs.get("video_grid_thw", [])]
-        second_per_grid_ts = kwargs.get("second_per_grid_ts", [])
-
-        hf_config = self.config
-        image_token_id = hf_config.image_token_id
-        video_token_id = hf_config.video_token_id
-        vision_start_token_id = hf_config.vision_start_token_id
-        spatial_merge_size = hf_config.vision_config.spatial_merge_size
-        tokens_per_second = getattr(hf_config.vision_config, "tokens_per_second", 1.0)
-
-        input_tokens_tensor = torch.tensor(input_tokens)
-        vision_start_indices = torch.argwhere(
-            input_tokens_tensor == vision_start_token_id
-        ).squeeze(1)
-        vision_tokens = input_tokens_tensor[vision_start_indices + 1]
-        image_nums = (vision_tokens == image_token_id).sum()
-        video_nums = (vision_tokens == video_token_id).sum()
         llm_pos_ids_list: list = []

         st = 0
-        remain_images, remain_videos = image_nums, video_nums
-
-        image_index, video_index = 0, 0
-        for _ in range(image_nums + video_nums):
-            video_second_per_grid_t = 0.0
-            if remain_images > 0:
-                try:
-                    ed_image = input_tokens.index(image_token_id, st)
-                except ValueError:
-                    ed_image = len(input_tokens) + 1
-            else:
-                ed_image = len(input_tokens) + 1
-            if remain_videos > 0:
-                try:
-                    ed_video = input_tokens.index(video_token_id, st)
-                except ValueError:
-                    ed_video = len(input_tokens) + 1
-            else:
-                ed_video = len(input_tokens) + 1
-            if ed_image < ed_video:
-                t, h, w = image_grid_thw[image_index]
-                image_index += 1
-                remain_images -= 1
-                ed = ed_image
-            else:
-                t, h, w = video_grid_thw[video_index]
-                video_second_per_grid_t = 1.0
-                if second_per_grid_ts:
-                    video_second_per_grid_t = second_per_grid_ts[video_index]
-                video_index += 1
-                remain_videos -= 1
-                ed = ed_video
-
-            llm_grid_t, llm_grid_h, llm_grid_w = (
-                t,
-                h // spatial_merge_size,
-                w // spatial_merge_size,
-            )
-            text_len = ed - st
+        for (
+            offset,
+            llm_grid_t,
+            llm_grid_h,
+            llm_grid_w,
+            t_factor,
+        ) in self.iter_mm_grid_thw(mm_features):
+            text_len = offset - st
             st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
             llm_pos_ids_list.append(
-                torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
+                np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx
             )

-            t_index = (
-                (
-                    torch.arange(llm_grid_t)
-                    .view(-1, 1)
-                    .expand(-1, llm_grid_h * llm_grid_w)
-                    * video_second_per_grid_t
-                    * tokens_per_second
-                )
-                .long()
-                .flatten()
-            )
-            h_index = (
-                torch.arange(llm_grid_h)
-                .view(1, -1, 1)
-                .expand(llm_grid_t, -1, llm_grid_w)
-                .flatten()
-            )
-            w_index = (
-                torch.arange(llm_grid_w)
-                .view(1, 1, -1)
-                .expand(llm_grid_t, llm_grid_h, -1)
-                .flatten()
-            )
-            llm_pos_ids_list.append(
-                torch.stack([t_index, h_index, w_index]) + text_len + st_idx
-            )
-            st = ed + llm_grid_t * llm_grid_h * llm_grid_w
+            grid_indices = np.indices((llm_grid_t, llm_grid_h, llm_grid_w))
+            if t_factor != 1.0:
+                grid_indices[0] = (grid_indices[0] * t_factor).astype(np.int64)
+            llm_pos_ids_list.append(grid_indices.reshape(3, -1) + text_len + st_idx)
+            st = offset + llm_grid_t * llm_grid_h * llm_grid_w

         if st < len(input_tokens):
             st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
             text_len = len(input_tokens) - st
             llm_pos_ids_list.append(
-                torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
+                np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx
             )

-        llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
+        llm_positions = np.concatenate(llm_pos_ids_list, axis=1).reshape(3, -1)
         mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()

-        return llm_positions, mrope_position_delta
+        return torch.from_numpy(llm_positions), mrope_position_delta

     @classmethod
     def get_placeholder_str(cls, modality: str, i: int) -> str | None:
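The core of the rewrite is swapping three torch arange/view/expand chains for a single np.indices call. As a sanity check, here is a minimal standalone sketch (plain numpy/torch, outside vLLM, with made-up grid sizes and t_factor) showing that the two constructions produce identical (3, t*h*w) position triples:

import numpy as np
import torch

# Illustrative merged grid sizes and temporal scale factor.
llm_grid_t, llm_grid_h, llm_grid_w = 2, 3, 4
t_factor = 2.0  # stands in for second_per_grid_ts * tokens_per_second

# Old construction: three separate arange/view/expand chains.
t_index = (
    (
        torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w)
        * t_factor
    )
    .long()
    .flatten()
)
h_index = (
    torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
)
w_index = (
    torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
)
old = torch.stack([t_index, h_index, w_index])

# New construction: np.indices yields all three coordinate grids at once.
grid_indices = np.indices((llm_grid_t, llm_grid_h, llm_grid_w))
if t_factor != 1.0:
    grid_indices[0] = (grid_indices[0] * t_factor).astype(np.int64)
new = grid_indices.reshape(3, -1)

assert (old.numpy() == new).all()  # identical (3, t*h*w) index triples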
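For intuition about the layout get_mrope_input_positions produces, here is a toy walkthrough of the same arithmetic (a standalone re-derivation with invented sizes, not a call into vLLM): 5 text tokens, one image occupying a 1x2x2 merged grid (4 placeholder tokens starting at offset 5), then 3 trailing text tokens.

import numpy as np

# Toy sequence: 5 text tokens, then a 1x2x2 merged image grid
# (4 placeholder tokens at offset 5), then 3 trailing text tokens.
num_tokens, offset = 12, 5
grid_t, grid_h, grid_w = 1, 2, 2

pos = []
# Text before the image gets identical t/h/w positions 0..4.
pos.append(np.broadcast_to(np.arange(offset), (3, offset)))
# The image block gets its (t, h, w) grid coordinates, shifted past the text.
pos.append(np.indices((grid_t, grid_h, grid_w)).reshape(3, -1) + offset)
# Trailing text resumes one past the largest position used so far.
st = offset + grid_t * grid_h * grid_w   # 9: first token index after the image
st_idx = pos[-1].max() + 1               # 7: next free position value
text_len = num_tokens - st               # 3 trailing text tokens
pos.append(np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx)

llm_positions = np.concatenate(pos, axis=1)
print(llm_positions)
# [[0 1 2 3 4 5 5 5 5 7 8 9]
#  [0 1 2 3 4 5 5 6 6 7 8 9]
#  [0 1 2 3 4 5 6 5 6 7 8 9]]
print(llm_positions.max() + 1 - num_tokens)  # mrope_position_delta = -2

The negative delta reflects that the image's 4 placeholder tokens share position values spanning only 2 steps, so the sequence's position space is shorter than its token count; decode-time positions are shifted by this delta.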