[Refactor]: Use M-RoPE interface directly while defining model class instead of maintaining model specific M-RoPE implementation in mrope.py (#24172)
Signed-off-by: Divyansh Singhvi <divyanshsinghvi@gmail.com> Signed-off-by: dsinghvi <divyanshsinghvi@gmail.com> Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk> Co-authored-by: wwl2755 <wangwenlong2755@gmail.com>
This commit is contained in:
@@ -33,7 +33,7 @@ import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from transformers import BatchFeature
|
||||
from transformers import BatchFeature, PretrainedConfig
|
||||
from transformers.models.qwen2_vl import Qwen2VLImageProcessorFast
|
||||
from transformers.models.qwen2_vl.image_processing_qwen2_vl import (
|
||||
smart_resize as image_smart_resize,
|
||||
@@ -84,6 +84,7 @@ from vllm.utils import is_list_of
|
||||
from .interfaces import (
|
||||
MultiModalEmbeddings,
|
||||
SupportsLoRA,
|
||||
SupportsMRoPE,
|
||||
SupportsMultiModal,
|
||||
SupportsPP,
|
||||
)
|
||||
@@ -1174,7 +1175,7 @@ class Qwen3LLMForCausalLM(Qwen3ForCausalLM):
|
||||
dummy_inputs=Qwen3VLDummyInputsBuilder,
|
||||
)
|
||||
class Qwen3VLForConditionalGeneration(
|
||||
nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP
|
||||
nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE
|
||||
):
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": [
|
||||
@@ -1480,6 +1481,116 @@ class Qwen3VLForConditionalGeneration(
|
||||
)
|
||||
return mm_input_by_modality
|
||||
|
||||
@classmethod
|
||||
def get_mrope_input_positions(
|
||||
cls,
|
||||
input_tokens: list[int],
|
||||
hf_config: PretrainedConfig,
|
||||
image_grid_thw: Union[list[list[int]], torch.Tensor],
|
||||
video_grid_thw: Union[list[list[int]], torch.Tensor],
|
||||
context_len: int = 0,
|
||||
seq_len: Optional[int] = None,
|
||||
second_per_grid_ts: Optional[list[float]] = None,
|
||||
audio_feature_lengths: Optional[torch.Tensor] = None,
|
||||
use_audio_in_video: bool = False,
|
||||
) -> tuple[torch.Tensor, int]:
|
||||
"""Get mrope input positions and delta value."""
|
||||
|
||||
video_grid_thw = [[1, h, w] for t, h, w in video_grid_thw for _ in range(t)]
|
||||
|
||||
image_token_id = hf_config.image_token_id
|
||||
video_token_id = hf_config.video_token_id
|
||||
vision_start_token_id = hf_config.vision_start_token_id
|
||||
spatial_merge_size = hf_config.vision_config.spatial_merge_size
|
||||
|
||||
input_tokens_tensor = torch.tensor(input_tokens)
|
||||
vision_start_indices = torch.argwhere(
|
||||
input_tokens_tensor == vision_start_token_id
|
||||
).squeeze(1)
|
||||
vision_tokens = input_tokens_tensor[vision_start_indices + 1]
|
||||
image_nums = (vision_tokens == image_token_id).sum()
|
||||
video_nums = (vision_tokens == video_token_id).sum()
|
||||
llm_pos_ids_list: list = []
|
||||
|
||||
st = 0
|
||||
remain_images, remain_videos = image_nums, video_nums
|
||||
|
||||
image_index, video_index = 0, 0
|
||||
for _ in range(image_nums + video_nums):
|
||||
if image_token_id in input_tokens and remain_images > 0:
|
||||
ed_image = input_tokens.index(image_token_id, st)
|
||||
else:
|
||||
ed_image = len(input_tokens) + 1
|
||||
if video_token_id in input_tokens and remain_videos > 0:
|
||||
ed_video = input_tokens.index(video_token_id, st)
|
||||
else:
|
||||
ed_video = len(input_tokens) + 1
|
||||
if ed_image < ed_video:
|
||||
t, h, w = (
|
||||
image_grid_thw[image_index][0],
|
||||
image_grid_thw[image_index][1],
|
||||
image_grid_thw[image_index][2],
|
||||
)
|
||||
image_index += 1
|
||||
remain_images -= 1
|
||||
ed = ed_image
|
||||
else:
|
||||
t, h, w = (
|
||||
video_grid_thw[video_index][0],
|
||||
video_grid_thw[video_index][1],
|
||||
video_grid_thw[video_index][2],
|
||||
)
|
||||
video_index += 1
|
||||
remain_videos -= 1
|
||||
ed = ed_video
|
||||
|
||||
llm_grid_t, llm_grid_h, llm_grid_w = (
|
||||
t,
|
||||
h // spatial_merge_size,
|
||||
w // spatial_merge_size,
|
||||
)
|
||||
text_len = ed - st
|
||||
|
||||
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
|
||||
llm_pos_ids_list.append(
|
||||
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
|
||||
)
|
||||
|
||||
t_index = (
|
||||
torch.arange(llm_grid_t)
|
||||
.view(-1, 1)
|
||||
.expand(-1, llm_grid_h * llm_grid_w)
|
||||
.flatten()
|
||||
)
|
||||
h_index = (
|
||||
torch.arange(llm_grid_h)
|
||||
.view(1, -1, 1)
|
||||
.expand(llm_grid_t, -1, llm_grid_w)
|
||||
.flatten()
|
||||
)
|
||||
w_index = (
|
||||
torch.arange(llm_grid_w)
|
||||
.view(1, 1, -1)
|
||||
.expand(llm_grid_t, llm_grid_h, -1)
|
||||
.flatten()
|
||||
)
|
||||
llm_pos_ids_list.append(
|
||||
torch.stack([t_index, h_index, w_index]) + text_len + st_idx
|
||||
)
|
||||
st = ed + llm_grid_t * llm_grid_h * llm_grid_w
|
||||
|
||||
if st < len(input_tokens):
|
||||
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
|
||||
text_len = len(input_tokens) - st
|
||||
llm_pos_ids_list.append(
|
||||
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
|
||||
)
|
||||
|
||||
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
|
||||
mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
|
||||
llm_positions = llm_positions[:, context_len:seq_len]
|
||||
return llm_positions, mrope_position_delta
|
||||
|
||||
def get_language_model(self) -> torch.nn.Module:
|
||||
return self.language_model
|
||||
|
||||
|
||||
Reference in New Issue
Block a user