[PERF] Speed up Qwen2.5-VL model by speed up rotary position embedding (#17973)
Signed-off-by: Vadim Gimpelson <vadim.gimpelson@centml.ai>
This commit is contained in:
@@ -25,7 +25,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""Inference-only Qwen2.5-VL model compatible with HuggingFace weights."""
|
"""Inference-only Qwen2.5-VL model compatible with HuggingFace weights."""
|
||||||
from collections.abc import Iterable, Mapping
|
from collections.abc import Iterable, Mapping
|
||||||
from functools import partial
|
from functools import lru_cache, partial
|
||||||
from typing import Callable, Literal, Optional, TypedDict, Union
|
from typing import Callable, Literal, Optional, TypedDict, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -478,8 +478,8 @@ class Qwen2_5_VisionRotaryEmbedding(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.dim = dim
|
self.dim = dim
|
||||||
self.theta = theta
|
self.theta = theta
|
||||||
inv_freq = 1.0 / (theta
|
inv_freq = 1.0 / (theta**(
|
||||||
**(torch.arange(0, dim, 2, dtype=torch.float) / dim))
|
torch.arange(0, dim, 2, dtype=torch.float, device='cpu') / dim))
|
||||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||||
self._seq_len_cached = 0
|
self._seq_len_cached = 0
|
||||||
self._freqs_cached = None
|
self._freqs_cached = None
|
||||||
@@ -520,7 +520,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
|
|||||||
self.hidden_size = vision_config.hidden_size
|
self.hidden_size = vision_config.hidden_size
|
||||||
self.num_heads = vision_config.num_heads
|
self.num_heads = vision_config.num_heads
|
||||||
|
|
||||||
# args for get_window_index
|
# args for get_window_index_thw
|
||||||
self.window_size = vision_config.window_size
|
self.window_size = vision_config.window_size
|
||||||
self.patch_size = vision_config.patch_size
|
self.patch_size = vision_config.patch_size
|
||||||
self.spatial_merge_size = vision_config.spatial_merge_size
|
self.spatial_merge_size = vision_config.spatial_merge_size
|
||||||
@@ -567,65 +567,71 @@ class Qwen2_5_VisionTransformer(nn.Module):
|
|||||||
def device(self) -> torch.device:
|
def device(self) -> torch.device:
|
||||||
return self.patch_embed.proj.weight.device
|
return self.patch_embed.proj.weight.device
|
||||||
|
|
||||||
def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
|
def rotary_pos_emb_thw(self, t, h, w):
|
||||||
pos_ids = []
|
hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
|
||||||
for t, h, w in grid_thw:
|
wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
|
||||||
hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
|
hpos_ids = hpos_ids.reshape(
|
||||||
wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
|
h // self.spatial_merge_size,
|
||||||
hpos_ids = hpos_ids.reshape(
|
self.spatial_merge_size,
|
||||||
h // self.spatial_merge_size,
|
w // self.spatial_merge_size,
|
||||||
self.spatial_merge_size,
|
self.spatial_merge_size,
|
||||||
w // self.spatial_merge_size,
|
).permute(0, 2, 1, 3).flatten()
|
||||||
self.spatial_merge_size,
|
wpos_ids = wpos_ids.reshape(
|
||||||
).permute(0, 2, 1, 3).flatten()
|
h // self.spatial_merge_size,
|
||||||
wpos_ids = wpos_ids.reshape(
|
self.spatial_merge_size,
|
||||||
h // self.spatial_merge_size,
|
w // self.spatial_merge_size,
|
||||||
self.spatial_merge_size,
|
self.spatial_merge_size,
|
||||||
w // self.spatial_merge_size,
|
).permute(0, 2, 1, 3).flatten()
|
||||||
self.spatial_merge_size,
|
pos_ids = torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)
|
||||||
).permute(0, 2, 1, 3).flatten()
|
max_size = max(h, w)
|
||||||
pos_ids.append(
|
rotary_pos_emb_full = self.rotary_pos_emb(max_size)
|
||||||
torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
|
|
||||||
pos_ids = torch.cat(pos_ids, dim=0)
|
|
||||||
max_grid_size = grid_thw[:, 1:].max()
|
|
||||||
rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
|
|
||||||
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
|
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
|
||||||
|
rotary_pos_emb = rotary_pos_emb.reshape(
|
||||||
|
rotary_pos_emb.shape[0] // self.spatial_merge_unit,
|
||||||
|
self.spatial_merge_unit, -1)
|
||||||
|
|
||||||
return rotary_pos_emb
|
return rotary_pos_emb
|
||||||
|
|
||||||
def get_window_index(self, grid_thw):
|
def get_window_index_thw(self, grid_t, grid_h, grid_w):
|
||||||
window_index: list = []
|
|
||||||
cu_window_seqlens: list = [0]
|
|
||||||
window_index_id = 0
|
|
||||||
vit_merger_window_size = (self.window_size //
|
vit_merger_window_size = (self.window_size //
|
||||||
self.spatial_merge_size // self.patch_size)
|
self.spatial_merge_size // self.patch_size)
|
||||||
|
|
||||||
for grid_t, grid_h, grid_w in grid_thw:
|
llm_grid_h = grid_h // self.spatial_merge_size
|
||||||
llm_grid_h = grid_h // self.spatial_merge_size
|
llm_grid_w = grid_w // self.spatial_merge_size
|
||||||
llm_grid_w = grid_w // self.spatial_merge_size
|
index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(
|
||||||
index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(
|
grid_t, llm_grid_h, llm_grid_w)
|
||||||
grid_t, llm_grid_h, llm_grid_w)
|
pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
|
||||||
pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
|
pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
|
||||||
pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
|
num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
|
||||||
num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
|
num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size
|
||||||
num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size
|
index_padded = F.pad(index, (0, pad_w, 0, pad_h), 'constant', -100)
|
||||||
index_padded = F.pad(index, (0, pad_w, 0, pad_h), 'constant', -100)
|
index_padded = index_padded.reshape(grid_t, num_windows_h,
|
||||||
index_padded = index_padded.reshape(grid_t, num_windows_h,
|
vit_merger_window_size,
|
||||||
vit_merger_window_size,
|
num_windows_w,
|
||||||
num_windows_w,
|
vit_merger_window_size)
|
||||||
vit_merger_window_size)
|
index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape(
|
||||||
index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape(
|
grid_t, num_windows_h * num_windows_w, vit_merger_window_size,
|
||||||
grid_t, num_windows_h * num_windows_w, vit_merger_window_size,
|
vit_merger_window_size)
|
||||||
vit_merger_window_size)
|
seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)
|
||||||
seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)
|
index_padded = index_padded.reshape(-1)
|
||||||
index_padded = index_padded.reshape(-1)
|
index_new = index_padded[index_padded != -100]
|
||||||
index_new = index_padded[index_padded != -100]
|
cu_seqlens_tmp = seqlens.cumsum(0) * self.spatial_merge_unit
|
||||||
window_index.append(index_new + window_index_id)
|
cu_seqlens_tmp = cu_seqlens_tmp.to(dtype=torch.int32)
|
||||||
cu_seqlens_tmp = seqlens.cumsum(
|
cu_seqlens_tmp = torch.unique_consecutive(cu_seqlens_tmp)
|
||||||
0) * self.spatial_merge_unit + cu_window_seqlens[-1]
|
|
||||||
cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
|
return index_new, cu_seqlens_tmp
|
||||||
window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()
|
|
||||||
window_index = torch.cat(window_index, dim=0)
|
@lru_cache(maxsize=1024) # noqa: B019
|
||||||
return window_index, cu_window_seqlens
|
def get_rope_by_thw(self, t, h, w):
|
||||||
|
window_index_thw, cu_seqlens_window_thw = self.get_window_index_thw(
|
||||||
|
t, h, w)
|
||||||
|
rotary_pos_emb_thw = self.rotary_pos_emb_thw(t, h, w)
|
||||||
|
rotary_pos_emb_thw = rotary_pos_emb_thw[window_index_thw, :, :]
|
||||||
|
rotary_pos_emb_thw = rotary_pos_emb_thw.flatten(start_dim=0, end_dim=1)
|
||||||
|
cu_seqlens_thw = torch.repeat_interleave(
|
||||||
|
torch.tensor([h * w], dtype=torch.int32), t)
|
||||||
|
return (rotary_pos_emb_thw, window_index_thw, cu_seqlens_window_thw,
|
||||||
|
cu_seqlens_thw)
|
||||||
|
|
||||||
def compute_attn_mask_seqlen(
|
def compute_attn_mask_seqlen(
|
||||||
self,
|
self,
|
||||||
@@ -641,45 +647,74 @@ class Qwen2_5_VisionTransformer(nn.Module):
|
|||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
grid_thw: torch.Tensor,
|
grid_thw: list[list[int]],
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
# patchify
|
# patchify
|
||||||
|
seq_len, _ = x.size()
|
||||||
|
rotary_pos_emb = []
|
||||||
|
window_index: list = []
|
||||||
|
cu_window_seqlens: list = [torch.tensor([0], dtype=torch.int32)]
|
||||||
|
cu_seqlens: list = []
|
||||||
|
|
||||||
hidden_states = x.to(device=self.device, dtype=self.dtype)
|
hidden_states = x.to(device=self.device, dtype=self.dtype)
|
||||||
hidden_states = self.patch_embed(hidden_states)
|
hidden_states = self.patch_embed(hidden_states)
|
||||||
|
|
||||||
# compute position embedding
|
window_index_id = 0
|
||||||
rotary_pos_emb = self.rot_pos_emb(grid_thw)
|
cu_window_seqlens_last = 0
|
||||||
|
for t, h, w in grid_thw:
|
||||||
|
t, h, w = int(t), int(h), int(w)
|
||||||
|
llm_h = h // self.spatial_merge_size
|
||||||
|
llm_w = w // self.spatial_merge_size
|
||||||
|
|
||||||
# windows attention
|
(
|
||||||
window_index, cu_window_seqlens = self.get_window_index(grid_thw)
|
rotary_pos_emb_thw,
|
||||||
cu_window_seqlens = torch.tensor(
|
window_index_thw,
|
||||||
cu_window_seqlens,
|
cu_seqlens_window_thw,
|
||||||
device=hidden_states.device,
|
cu_seqlens_thw,
|
||||||
dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32)
|
) = self.get_rope_by_thw(t, h, w)
|
||||||
|
|
||||||
|
window_index.append(window_index_thw + window_index_id)
|
||||||
|
window_index_id += (t * llm_h * llm_w)
|
||||||
|
|
||||||
|
cu_seqlens_window_thw = (cu_seqlens_window_thw +
|
||||||
|
cu_window_seqlens_last)
|
||||||
|
cu_window_seqlens_last = cu_seqlens_window_thw[-1]
|
||||||
|
cu_window_seqlens.append(cu_seqlens_window_thw)
|
||||||
|
|
||||||
|
rotary_pos_emb.append(rotary_pos_emb_thw)
|
||||||
|
|
||||||
|
cu_seqlens.append(cu_seqlens_thw)
|
||||||
|
|
||||||
|
rotary_pos_emb = torch.cat(rotary_pos_emb)
|
||||||
|
window_index = torch.cat(window_index)
|
||||||
|
cu_window_seqlens = torch.cat(cu_window_seqlens)
|
||||||
cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)
|
cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)
|
||||||
seq_len, _ = hidden_states.size()
|
cu_seqlens = torch.cat(cu_seqlens)
|
||||||
hidden_states = hidden_states.reshape(
|
cu_seqlens = torch.cumsum(cu_seqlens, dim=0, dtype=torch.int32)
|
||||||
seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
|
|
||||||
hidden_states = hidden_states[window_index, :, :]
|
|
||||||
hidden_states = hidden_states.reshape(seq_len, -1)
|
|
||||||
rotary_pos_emb = rotary_pos_emb.reshape(
|
|
||||||
seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
|
|
||||||
rotary_pos_emb = rotary_pos_emb[window_index, :, :]
|
|
||||||
rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
|
|
||||||
# compute cu_seqlens
|
|
||||||
cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2],
|
|
||||||
grid_thw[:, 0]).cumsum(
|
|
||||||
dim=0, dtype=torch.int32)
|
|
||||||
cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)
|
cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)
|
||||||
|
|
||||||
# transformers
|
# transformers
|
||||||
hidden_states = hidden_states.unsqueeze(1)
|
|
||||||
|
|
||||||
# pre-compute seqlens for window/full attn to reduce cuMemcpy operations
|
# pre-compute seqlens for window/full attn to reduce cuMemcpy operations
|
||||||
max_seqlen_full, seqlens_full = self.compute_attn_mask_seqlen(
|
max_seqlen_full, seqlens_full = self.compute_attn_mask_seqlen(
|
||||||
cu_seqlens)
|
cu_seqlens)
|
||||||
max_seqlen_window, seqlens_window = self.compute_attn_mask_seqlen(
|
max_seqlen_window, seqlens_window = self.compute_attn_mask_seqlen(
|
||||||
cu_window_seqlens)
|
cu_window_seqlens)
|
||||||
|
|
||||||
|
cu_seqlens = cu_seqlens.to(device=self.device, non_blocking=True)
|
||||||
|
cu_window_seqlens = cu_window_seqlens.to(device=self.device,
|
||||||
|
non_blocking=True)
|
||||||
|
rotary_pos_emb = rotary_pos_emb.to(device=self.device,
|
||||||
|
non_blocking=True)
|
||||||
|
window_index = window_index.to(device=hidden_states.device,
|
||||||
|
non_blocking=True)
|
||||||
|
|
||||||
|
hidden_states = hidden_states.reshape(
|
||||||
|
seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
|
||||||
|
hidden_states = hidden_states[window_index, :, :]
|
||||||
|
hidden_states = hidden_states.reshape(seq_len, -1)
|
||||||
|
|
||||||
|
hidden_states = hidden_states.unsqueeze(1)
|
||||||
|
|
||||||
for layer_num, blk in enumerate(self.blocks):
|
for layer_num, blk in enumerate(self.blocks):
|
||||||
if layer_num in self.fullatt_block_indexes:
|
if layer_num in self.fullatt_block_indexes:
|
||||||
cu_seqlens_now = cu_seqlens
|
cu_seqlens_now = cu_seqlens
|
||||||
@@ -932,12 +967,13 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
|
|||||||
|
|
||||||
grid_thw = image_input["image_grid_thw"]
|
grid_thw = image_input["image_grid_thw"]
|
||||||
assert grid_thw.ndim == 2
|
assert grid_thw.ndim == 2
|
||||||
|
grid_thw_list = grid_thw.tolist()
|
||||||
|
|
||||||
if image_input["type"] == "image_embeds":
|
if image_input["type"] == "image_embeds":
|
||||||
image_embeds = image_input["image_embeds"].type(self.visual.dtype)
|
image_embeds = image_input["image_embeds"].type(self.visual.dtype)
|
||||||
else:
|
else:
|
||||||
pixel_values = image_input["pixel_values"].type(self.visual.dtype)
|
pixel_values = image_input["pixel_values"].type(self.visual.dtype)
|
||||||
image_embeds = self.visual(pixel_values, grid_thw=grid_thw)
|
image_embeds = self.visual(pixel_values, grid_thw=grid_thw_list)
|
||||||
|
|
||||||
# Split concatenated embeddings for each image item.
|
# Split concatenated embeddings for each image item.
|
||||||
merge_size = self.visual.spatial_merge_size
|
merge_size = self.visual.spatial_merge_size
|
||||||
@@ -951,13 +987,15 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
|
|||||||
|
|
||||||
grid_thw = video_input["video_grid_thw"]
|
grid_thw = video_input["video_grid_thw"]
|
||||||
assert grid_thw.ndim == 2
|
assert grid_thw.ndim == 2
|
||||||
|
grid_thw_list = grid_thw.tolist()
|
||||||
|
|
||||||
if video_input["type"] == "video_embeds":
|
if video_input["type"] == "video_embeds":
|
||||||
video_embeds = video_input["video_embeds"].type(self.visual.dtype)
|
video_embeds = video_input["video_embeds"].type(self.visual.dtype)
|
||||||
else:
|
else:
|
||||||
pixel_values_videos = video_input["pixel_values_videos"].type(
|
pixel_values_videos = video_input["pixel_values_videos"].type(
|
||||||
self.visual.dtype)
|
self.visual.dtype)
|
||||||
video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw)
|
video_embeds = self.visual(pixel_values_videos,
|
||||||
|
grid_thw=grid_thw_list)
|
||||||
|
|
||||||
# Split concatenated embeddings for each video item.
|
# Split concatenated embeddings for each video item.
|
||||||
merge_size = self.visual.spatial_merge_size
|
merge_size = self.visual.spatial_merge_size
|
||||||
|
|||||||
Reference in New Issue
Block a user