[Bugfix] [Model] Missing MRoPE function definition from KeyeForConditionalGeneration (#27895)

Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
This commit is contained in:
TJian
2025-10-31 22:45:23 -07:00
committed by GitHub
parent 879a06579e
commit e2347dbf58
2 changed files with 254 additions and 17 deletions

View File

@@ -17,7 +17,9 @@ from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPo
from transformers.utils import torch_int
from vllm.attention.backends.registry import _Backend
from vllm.attention.layer import check_upstream_fa_availability
from vllm.attention.layer import (
maybe_get_vit_flash_attn_backend,
)
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import get_tensor_model_parallel_world_size
@@ -56,12 +58,14 @@ from vllm.multimodal.processing import (
PromptUpdate,
)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import (
MultiModalEmbeddings,
SupportsLoRA,
SupportsMRoPE,
SupportsMultiModal,
SupportsPP,
)
@@ -337,7 +341,10 @@ def apply_rotary_pos_emb_flashatt(
cos = cos.chunk(2, dim=-1)[0].contiguous()
sin = sin.chunk(2, dim=-1)[0].contiguous()
from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
if current_platform.is_cuda():
from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
elif current_platform.is_rocm():
from flash_attn.ops.triton.rotary import apply_rotary as apply_rotary_emb
q_embed = apply_rotary_emb(q.float(), cos.float(), sin.float()).type_as(q)
k_embed = apply_rotary_emb(k.float(), cos.float(), sin.float()).type_as(k)
@@ -398,18 +405,28 @@ class KeyeSiglipAttention(nn.Module):
attn_backend_override=attn_backend_override,
)
self.use_upstream_fa = False
if self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability(
torch.get_default_dtype()
):
self.attn_backend = _Backend.FLASH_ATTN
self.use_upstream_fa = True
self.attn_backend, self.flash_attn_varlen_func = (
maybe_get_vit_flash_attn_backend(
self.attn_backend,
use_upstream_fa=False,
attn_backend_override=attn_backend_override,
)
)
if self.attn_backend not in {_Backend.FLASH_ATTN, _Backend.XFORMERS}:
if self.attn_backend not in {
_Backend.FLASH_ATTN,
_Backend.XFORMERS,
_Backend.ROCM_AITER_FA,
}:
raise RuntimeError(
f"Keye-VL does not support {self.attn_backend} backend now."
)
self.is_flash_attn_backend = self.attn_backend in {
_Backend.FLASH_ATTN,
_Backend.ROCM_AITER_FA,
}
def forward(
self,
hidden_states: torch.Tensor,
@@ -457,15 +474,10 @@ class KeyeSiglipAttention(nn.Module):
self.head_dim,
)
if self.attn_backend == _Backend.FLASH_ATTN:
if self.use_upstream_fa:
from flash_attn import flash_attn_varlen_func
else:
from vllm.vllm_flash_attn import flash_attn_varlen_func
if self.is_flash_attn_backend:
q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
output = flash_attn_varlen_func(
output = self.flash_attn_varlen_func(
q,
k,
v,
@@ -1542,7 +1554,7 @@ class BaseKeyeModule(nn.Module):
dummy_inputs=KeyeDummyInputsBuilder,
)
class KeyeForConditionalGeneration(
BaseKeyeModule, SupportsMultiModal, SupportsLoRA, SupportsPP
BaseKeyeModule, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE
):
def _build_projector(
self,
@@ -1611,3 +1623,142 @@ class KeyeForConditionalGeneration(
return tuple(
self._process_video_embeds(video_type, video_grid_thw, pixel_values_videos)
)
def get_mrope_input_positions(
self,
input_tokens: list[int],
hf_config: PretrainedConfig,
image_grid_thw: list[list[int]] | torch.Tensor,
video_grid_thw: list[list[int]] | torch.Tensor,
context_len: int = 0,
seq_len: int | None = None,
second_per_grid_ts: list[float] | None = None,
audio_feature_lengths: torch.Tensor | None = None,
use_audio_in_video: bool = False,
) -> tuple[torch.Tensor, int]:
if isinstance(video_grid_thw, list) and len(video_grid_thw) > 0:
video_grid_thw = video_grid_thw[0]
"""Get mrope input positions and delta value (Keye series)."""
def split_thw(grid_thw: torch.Tensor | list[int]) -> list[list[int]]:
"""
Split grid_thw along the t dimension.
Args:
grid_thw: shape [N, 3] tensor or nested list of [t, h, w].
Returns:
List of [1, h, w] rows, repeated t times for each original row.
"""
if isinstance(grid_thw, list):
grid_thw = torch.tensor(grid_thw, dtype=torch.long)
if grid_thw.numel() == 0:
return []
t, hw = grid_thw[:, 0], grid_thw[:, 1:]
ones = torch.ones_like(hw[:, :1]) # [N,1]
out = torch.cat([ones, hw], dim=1).repeat_interleave(t, dim=0)
return out.tolist()
video_grid_thw = split_thw(video_grid_thw)
image_token_id = hf_config.image_token_id
video_token_id = hf_config.video_token_id
spatial_merge_size = hf_config.vision_config.spatial_merge_size
image_nums = len(image_grid_thw)
frame_nums = len(video_grid_thw)
llm_pos_ids_list: list = []
st = 0
remain_images, remain_frames = image_nums, frame_nums
image_index, video_index = 0, 0
for _ in range(image_nums + frame_nums):
if remain_images > 0:
try:
ed_image = input_tokens.index(image_token_id, st)
except ValueError:
ed_image = len(input_tokens) + 1
else:
ed_image = len(input_tokens) + 1
if remain_frames > 0:
try:
ed_video = input_tokens.index(video_token_id, st)
except ValueError:
ed_video = len(input_tokens) + 1
else:
ed_video = len(input_tokens) + 1
if ed_image < ed_video:
t, h, w = (
image_grid_thw[image_index][0],
image_grid_thw[image_index][1],
image_grid_thw[image_index][2],
)
image_index += 1
remain_images -= 1
ed = ed_image
else:
t, h, w = (
video_grid_thw[video_index][0],
video_grid_thw[video_index][1],
video_grid_thw[video_index][2],
)
video_index += 1
remain_frames -= 1
ed = ed_video
llm_grid_t, llm_grid_h, llm_grid_w = (
t,
h // spatial_merge_size,
w // spatial_merge_size,
)
text_len = ed - st
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
llm_pos_ids_list.append(
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
)
t_index = (
(
torch.arange(llm_grid_t)
.view(-1, 1)
.expand(-1, llm_grid_h * llm_grid_w)
)
.long()
.flatten()
)
h_index = (
torch.arange(llm_grid_h)
.view(1, -1, 1)
.expand(llm_grid_t, -1, llm_grid_w)
.flatten()
)
w_index = (
torch.arange(llm_grid_w)
.view(1, 1, -1)
.expand(llm_grid_t, llm_grid_h, -1)
.flatten()
)
llm_pos_ids_list.append(
torch.stack([t_index, h_index, w_index]) + text_len + st_idx
)
st = ed + llm_grid_t * llm_grid_h * llm_grid_w
if st < len(input_tokens):
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
text_len = len(input_tokens) - st
llm_pos_ids_list.append(
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
)
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
llm_positions = llm_positions[:, context_len:seq_len]
return llm_positions, mrope_position_delta