[Refactor]: Use the M-RoPE interface directly when defining the model class instead of maintaining a model-specific M-RoPE implementation in mrope.py (#24172)
Signed-off-by: Divyansh Singhvi <divyanshsinghvi@gmail.com>
Signed-off-by: dsinghvi <divyanshsinghvi@gmail.com>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: wwl2755 <wangwenlong2755@gmail.com>
@@ -38,7 +38,7 @@ from vllm.multimodal.processing import (
 )
 from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
-from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
+from .interfaces import SupportsLoRA, SupportsMRoPE, SupportsMultiModal, SupportsPP
 from .keye import (
     BaseKeyeModule,
     BaseMultiModalProcessor,
@@ -493,7 +493,7 @@ class KeyeVL1_5DummyInputsBuilder(
     dummy_inputs=KeyeVL1_5DummyInputsBuilder,
 )
 class KeyeVL1_5ForConditionalGeneration(
-    BaseKeyeModule, SupportsMultiModal, SupportsLoRA, SupportsPP
+    BaseKeyeModule, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE
 ):
     def _build_projector(
         self,
@@ -589,3 +589,143 @@ class KeyeVL1_5ForConditionalGeneration(
             end = patch_cu_seqlens[idx + 1]
             new_video_embeds.append(video_embeds[start:end])
         return tuple(new_video_embeds)
+
+    @classmethod
+    def get_mrope_input_positions(
+        cls,
+        input_tokens: list[int],
+        hf_config: PretrainedConfig,
+        image_grid_thw: Union[list[list[int]], torch.Tensor],
+        video_grid_thw: Union[list[list[int]], torch.Tensor],
+        context_len: int = 0,
+        seq_len: Optional[int] = None,
+        second_per_grid_ts: Optional[list[float]] = None,
+        audio_feature_lengths: Optional[torch.Tensor] = None,
+        use_audio_in_video: bool = False,
+    ) -> tuple[torch.Tensor, int]:
+        """Get mrope input positions and delta value (Keye series)."""
+        if isinstance(video_grid_thw, list) and len(video_grid_thw) > 0:
+            video_grid_thw = video_grid_thw[0]
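+        # second_per_grid_ts, audio_feature_lengths and use_audio_in_video
+        # are unused below; they are presumably accepted so the signature
+        # matches the shared SupportsMRoPE interface.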
+
+        def split_thw(grid_thw: Union[torch.Tensor, list[int]]) -> list[list[int]]:
+            """
+            Split grid_thw along the t dimension.
+
+            Args:
+                grid_thw: shape [N, 3] tensor or nested list of [t, h, w].
+
+            Returns:
+                List of [1, h, w] rows, repeated t times for each original row.
+            """
+
+            if isinstance(grid_thw, list):
+                grid_thw = torch.tensor(grid_thw, dtype=torch.long)
+
+            if grid_thw.numel() == 0:
+                return []
+
+            t, hw = grid_thw[:, 0], grid_thw[:, 1:]
+            ones = torch.ones_like(hw[:, :1])  # [N, 1]
+            out = torch.cat([ones, hw], dim=1).repeat_interleave(t, dim=0)
+            return out.tolist()
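+        # For example (illustrative): split_thw([[2, 4, 6]]) returns
+        # [[1, 4, 6], [1, 4, 6]] -- one [1, h, w] row per frame.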
+
+        video_grid_thw = split_thw(video_grid_thw)
+
+        image_token_id = hf_config.image_token_id
+        video_token_id = hf_config.video_token_id
+        spatial_merge_size = hf_config.vision_config.spatial_merge_size
+
+        image_nums = len(image_grid_thw)
+        frame_nums = len(video_grid_thw)
+        llm_pos_ids_list: list = []
+
+        st = 0
+        remain_images, remain_frames = image_nums, frame_nums
+
+        image_index, video_index = 0, 0
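+        # Walk the token stream once, consuming image/video placeholder runs
+        # in the order in which they appear.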
+        for _ in range(image_nums + frame_nums):
+            if remain_images > 0:
+                try:
+                    ed_image = input_tokens.index(image_token_id, st)
+                except ValueError:
+                    ed_image = len(input_tokens) + 1
+            else:
+                ed_image = len(input_tokens) + 1
+            if remain_frames > 0:
+                try:
+                    ed_video = input_tokens.index(video_token_id, st)
+                except ValueError:
+                    ed_video = len(input_tokens) + 1
+            else:
+                ed_video = len(input_tokens) + 1
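+            # len(input_tokens) + 1 serves as a "not found" sentinel, so the
+            # comparison below always selects the modality whose placeholder
+            # occurs first in the remaining tokens.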
+
+            if ed_image < ed_video:
+                t, h, w = (
+                    image_grid_thw[image_index][0],
+                    image_grid_thw[image_index][1],
+                    image_grid_thw[image_index][2],
+                )
+                image_index += 1
+                remain_images -= 1
+                ed = ed_image
+            else:
+                t, h, w = (
+                    video_grid_thw[video_index][0],
+                    video_grid_thw[video_index][1],
+                    video_grid_thw[video_index][2],
+                )
+                video_index += 1
+                remain_frames -= 1
+                ed = ed_video
+
+            llm_grid_t, llm_grid_h, llm_grid_w = (
+                t,
+                h // spatial_merge_size,
+                w // spatial_merge_size,
+            )
+            text_len = ed - st
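+
+            # Text tokens before this vision chunk share the same position id
+            # on all three axes (plain 1-D RoPE behaviour), continuing from
+            # the previous chunk's maximum.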
+            st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
+            llm_pos_ids_list.append(
+                torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
+            )
+
+            t_index = (
+                (
+                    torch.arange(llm_grid_t)
+                    .view(-1, 1)
+                    .expand(-1, llm_grid_h * llm_grid_w)
+                )
+                .long()
+                .flatten()
+            )
+
+            h_index = (
+                torch.arange(llm_grid_h)
+                .view(1, -1, 1)
+                .expand(llm_grid_t, -1, llm_grid_w)
+                .flatten()
+            )
+            w_index = (
+                torch.arange(llm_grid_w)
+                .view(1, 1, -1)
+                .expand(llm_grid_t, llm_grid_h, -1)
+                .flatten()
+            )
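+            # Stacking the flattened t/h/w grids yields a [3, t*h*w] block of
+            # per-axis position ids for the vision tokens, offset past the
+            # preceding text.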
+            llm_pos_ids_list.append(
+                torch.stack([t_index, h_index, w_index]) + text_len + st_idx
+            )
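+            # Advance past the placeholder run whose positions were just
+            # emitted.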
+            st = ed + llm_grid_t * llm_grid_h * llm_grid_w
+
+        if st < len(input_tokens):
+            st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
+            text_len = len(input_tokens) - st
+            llm_pos_ids_list.append(
+                torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
+            )
+
+        llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
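+        # The delta records how far the 3-D position range falls short of
+        # (or exceeds) the raw token count; decode-time positions are shifted
+        # by this amount.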
+        mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
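+        # Return only the window that still needs positions computed.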
+        llm_positions = llm_positions[:, context_len:seq_len]
+
+        return llm_positions, mrope_position_delta
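
For illustration (not part of the commit): a minimal sketch of how the new interface method could be exercised on its own. The config object and token ids below are made-up stand-ins for the real Keye configuration, and the import path is an assumption.

    from types import SimpleNamespace

    from vllm.model_executor.models.keye_vl1_5 import (  # assumed module path
        KeyeVL1_5ForConditionalGeneration,
    )

    # Hypothetical stand-ins; the real values come from the HF config.
    hf_config = SimpleNamespace(
        image_token_id=151655,
        video_token_id=151656,
        vision_config=SimpleNamespace(spatial_merge_size=2),
    )

    # One image with grid t=1, h=4, w=4 -> (4 // 2) * (4 // 2) = 4 placeholders.
    img = hf_config.image_token_id
    input_tokens = [101, 102, img, img, img, img, 103]

    positions, delta = KeyeVL1_5ForConditionalGeneration.get_mrope_input_positions(
        input_tokens,
        hf_config,
        image_grid_thw=[[1, 4, 4]],
        video_grid_thw=[],
        seq_len=len(input_tokens),
    )
    print(positions.shape)  # torch.Size([3, 7])
    print(delta)            # -2: the 3-D positions advance 5 steps for 7 tokens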