[Model] Use mm_position to compute mrope positions for Qwen2-VL/2.5-VL (#32126)

Signed-off-by: YunzhuLu <lucia.yunzhu@gmail.com>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
YunzhuLu
2026-01-13 17:04:29 +08:00
committed by GitHub
parent df7e12715f
commit 542a4059b2
2 changed files with 113 additions and 190 deletions

View File

@@ -26,7 +26,7 @@
"""Inference-only Qwen2-VL model compatible with HuggingFace weights."""
import math
from collections.abc import Callable, Iterable, Mapping, Sequence
from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence
from functools import partial
from typing import Annotated, Any, Literal, TypeAlias
@@ -1137,121 +1137,82 @@ class Qwen2VLForConditionalGeneration(
supports_encoder_tp_data = True
def iter_mm_grid_thw(
self, mm_features: list[MultiModalFeatureSpec]
) -> Iterator[tuple[int, int, int, int, float]]:
"""
Iterate over multimodal features and yield grid information.
Args:
mm_features: List of multimodal feature specifications
Yields:
Tuple of (offset, grid_t, grid_h, grid_w, t_factor) for each frame/image
"""
spatial_merge_size = self.config.vision_config.spatial_merge_size
tokens_per_second = getattr(self.config.vision_config, "tokens_per_second", 1.0)
for mm_feature in sorted(mm_features, key=lambda f: f.mm_position.offset):
offset = mm_feature.mm_position.offset
if mm_feature.modality == "image":
t, h, w = mm_feature.data["image_grid_thw"].data.tolist()
assert t == 1, f"Image must have 1 frame, got {t}"
yield offset, 1, h // spatial_merge_size, w // spatial_merge_size, 1.0
elif mm_feature.modality == "video":
t, h, w = mm_feature.data["video_grid_thw"].data.tolist()
second_per_grid_ts = 1.0
if mm_feature.data.get("second_per_grid_ts", None):
second_per_grid_ts = mm_feature.data[
"second_per_grid_ts"
].data.item()
t_factor = second_per_grid_ts * tokens_per_second
yield (
offset,
t,
h // spatial_merge_size,
w // spatial_merge_size,
t_factor,
)
else:
raise ValueError(f"Unsupported modality: {mm_feature.modality}")
def get_mrope_input_positions(
self,
input_tokens: list[int],
mm_features: list[MultiModalFeatureSpec],
) -> tuple[torch.Tensor, int]:
kwargs = MultiModalFeatureSpec.gather_kwargs(
mm_features,
{"image_grid_thw", "video_grid_thw", "second_per_grid_ts"},
)
image_grid_thw = [item.tolist() for item in kwargs.get("image_grid_thw", [])]
video_grid_thw = [item.tolist() for item in kwargs.get("video_grid_thw", [])]
second_per_grid_ts = kwargs.get("second_per_grid_ts", [])
hf_config = self.config
image_token_id = hf_config.image_token_id
video_token_id = hf_config.video_token_id
vision_start_token_id = hf_config.vision_start_token_id
spatial_merge_size = hf_config.vision_config.spatial_merge_size
tokens_per_second = getattr(hf_config.vision_config, "tokens_per_second", 1.0)
input_tokens_tensor = torch.tensor(input_tokens)
vision_start_indices = torch.argwhere(
input_tokens_tensor == vision_start_token_id
).squeeze(1)
vision_tokens = input_tokens_tensor[vision_start_indices + 1]
image_nums = (vision_tokens == image_token_id).sum()
video_nums = (vision_tokens == video_token_id).sum()
llm_pos_ids_list: list = []
st = 0
remain_images, remain_videos = image_nums, video_nums
image_index, video_index = 0, 0
for _ in range(image_nums + video_nums):
video_second_per_grid_t = 0.0
if remain_images > 0:
try:
ed_image = input_tokens.index(image_token_id, st)
except ValueError:
ed_image = len(input_tokens) + 1
else:
ed_image = len(input_tokens) + 1
if remain_videos > 0:
try:
ed_video = input_tokens.index(video_token_id, st)
except ValueError:
ed_video = len(input_tokens) + 1
else:
ed_video = len(input_tokens) + 1
if ed_image < ed_video:
t, h, w = image_grid_thw[image_index]
image_index += 1
remain_images -= 1
ed = ed_image
else:
t, h, w = video_grid_thw[video_index]
video_second_per_grid_t = 1.0
if second_per_grid_ts:
video_second_per_grid_t = second_per_grid_ts[video_index]
video_index += 1
remain_videos -= 1
ed = ed_video
llm_grid_t, llm_grid_h, llm_grid_w = (
t,
h // spatial_merge_size,
w // spatial_merge_size,
)
text_len = ed - st
for (
offset,
llm_grid_t,
llm_grid_h,
llm_grid_w,
t_factor,
) in self.iter_mm_grid_thw(mm_features):
text_len = offset - st
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
llm_pos_ids_list.append(
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx
)
t_index = (
(
torch.arange(llm_grid_t)
.view(-1, 1)
.expand(-1, llm_grid_h * llm_grid_w)
* video_second_per_grid_t
* tokens_per_second
)
.long()
.flatten()
)
h_index = (
torch.arange(llm_grid_h)
.view(1, -1, 1)
.expand(llm_grid_t, -1, llm_grid_w)
.flatten()
)
w_index = (
torch.arange(llm_grid_w)
.view(1, 1, -1)
.expand(llm_grid_t, llm_grid_h, -1)
.flatten()
)
llm_pos_ids_list.append(
torch.stack([t_index, h_index, w_index]) + text_len + st_idx
)
st = ed + llm_grid_t * llm_grid_h * llm_grid_w
grid_indices = np.indices((llm_grid_t, llm_grid_h, llm_grid_w))
if t_factor != 1.0:
grid_indices[0] = (grid_indices[0] * t_factor).astype(np.int64)
llm_pos_ids_list.append(grid_indices.reshape(3, -1) + text_len + st_idx)
st = offset + llm_grid_t * llm_grid_h * llm_grid_w
if st < len(input_tokens):
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
text_len = len(input_tokens) - st
llm_pos_ids_list.append(
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx
)
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
llm_positions = np.concatenate(llm_pos_ids_list, axis=1).reshape(3, -1)
mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
return llm_positions, mrope_position_delta
return torch.from_numpy(llm_positions), mrope_position_delta
@classmethod
def get_placeholder_str(cls, modality: str, i: int) -> str | None: