[Model][QwenVL] Replace torch.repeat_interleave with faster np.repeat (#28964)
Signed-off-by: Lukas Geiger <lukas.geiger94@gmail.com>
This commit is contained in:
@@ -128,12 +128,7 @@ def batch_make_image_embeddings(
|
|||||||
visual = model.visual
|
visual = model.visual
|
||||||
|
|
||||||
pixel_values_on_device = pixel_values.to(visual.device, dtype=visual.dtype)
|
pixel_values_on_device = pixel_values.to(visual.device, dtype=visual.dtype)
|
||||||
image_grid_thw_on_device = image_grid_thw.to(
|
return visual(pixel_values_on_device, grid_thw=image_grid_thw).cpu()
|
||||||
visual.device, dtype=torch.int64
|
|
||||||
)
|
|
||||||
return visual(
|
|
||||||
pixel_values_on_device, grid_thw=image_grid_thw_on_device
|
|
||||||
).cpu()
|
|
||||||
|
|
||||||
image_embeds = torch.concat(llm.apply_model(get_image_embeds))
|
image_embeds = torch.concat(llm.apply_model(get_image_embeds))
|
||||||
|
|
||||||
@@ -217,12 +212,7 @@ def batch_make_video_embeddings(
|
|||||||
visual = model.visual
|
visual = model.visual
|
||||||
|
|
||||||
pixel_values_on_device = pixel_values.to(visual.device, dtype=visual.dtype)
|
pixel_values_on_device = pixel_values.to(visual.device, dtype=visual.dtype)
|
||||||
video_grid_thw_on_device = video_grid_thw.to(
|
return visual(pixel_values_on_device, grid_thw=video_grid_thw).cpu()
|
||||||
visual.device, dtype=torch.int64
|
|
||||||
)
|
|
||||||
return visual(
|
|
||||||
pixel_values_on_device, grid_thw=video_grid_thw_on_device
|
|
||||||
).cpu()
|
|
||||||
|
|
||||||
video_embeds = torch.concat(llm.apply_model(get_image_embeds))
|
video_embeds = torch.concat(llm.apply_model(get_image_embeds))
|
||||||
|
|
||||||
|
|||||||
@@ -29,6 +29,7 @@ from collections.abc import Callable, Iterable, Mapping, Sequence
|
|||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Annotated, Any, Literal, TypeAlias
|
from typing import Annotated, Any, Literal, TypeAlias
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
@@ -751,25 +752,27 @@ class Qwen2VisionTransformer(nn.Module):
|
|||||||
|
|
||||||
if isinstance(grid_thw, list):
|
if isinstance(grid_thw, list):
|
||||||
grid_thw_list = grid_thw
|
grid_thw_list = grid_thw
|
||||||
grid_thw = torch.tensor(grid_thw, dtype=torch.int32)
|
grid_thw = np.array(grid_thw, dtype=np.int32)
|
||||||
else:
|
else:
|
||||||
grid_thw_list = grid_thw.tolist()
|
grid_thw_list = grid_thw.tolist()
|
||||||
|
grid_thw = grid_thw.numpy()
|
||||||
|
|
||||||
# compute position embedding
|
# compute position embedding
|
||||||
rotary_pos_emb_cos, rotary_pos_emb_sin = self.rot_pos_emb(grid_thw_list)
|
rotary_pos_emb_cos, rotary_pos_emb_sin = self.rot_pos_emb(grid_thw_list)
|
||||||
|
|
||||||
# compute cu_seqlens
|
# compute cu_seqlens
|
||||||
cu_seqlens = torch.repeat_interleave(
|
cu_seqlens = np.repeat(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
|
||||||
grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
|
axis=0, dtype=np.int32
|
||||||
).cumsum(dim=0, dtype=torch.int32)
|
)
|
||||||
cu_seqlens = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens])
|
cu_seqlens = np.concatenate([np.zeros(1, dtype=np.int32), cu_seqlens])
|
||||||
cu_seqlens = cu_seqlens.to(self.device, non_blocking=True)
|
cu_seqlens = torch.from_numpy(cu_seqlens)
|
||||||
|
|
||||||
# transformers
|
# transformers
|
||||||
x = x.unsqueeze(1)
|
x = x.unsqueeze(1)
|
||||||
|
|
||||||
# pre-compute seqlens for attn mask to reduce cuMemcpy operations
|
# pre-compute seqlens for attn mask to reduce cuMemcpy operations
|
||||||
max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
|
max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
|
||||||
|
cu_seqlens = cu_seqlens.to(self.device, non_blocking=True)
|
||||||
for blk in self.blocks:
|
for blk in self.blocks:
|
||||||
x = blk(
|
x = blk(
|
||||||
x,
|
x,
|
||||||
|
|||||||
@@ -553,18 +553,20 @@ class Qwen3_VisionTransformer(nn.Module):
|
|||||||
|
|
||||||
if isinstance(grid_thw, list):
|
if isinstance(grid_thw, list):
|
||||||
grid_thw_list = grid_thw
|
grid_thw_list = grid_thw
|
||||||
grid_thw = torch.tensor(grid_thw, dtype=torch.int32)
|
grid_thw = np.array(grid_thw, dtype=np.int32)
|
||||||
else:
|
else:
|
||||||
grid_thw_list = grid_thw.tolist()
|
grid_thw_list = grid_thw.tolist()
|
||||||
|
grid_thw = grid_thw.numpy()
|
||||||
|
|
||||||
pos_embeds = self.fast_pos_embed_interpolate(grid_thw_list)
|
pos_embeds = self.fast_pos_embed_interpolate(grid_thw_list)
|
||||||
hidden_states = hidden_states + pos_embeds
|
hidden_states = hidden_states + pos_embeds
|
||||||
rotary_pos_emb_cos, rotary_pos_emb_sin = self.rot_pos_emb(grid_thw_list)
|
rotary_pos_emb_cos, rotary_pos_emb_sin = self.rot_pos_emb(grid_thw_list)
|
||||||
|
|
||||||
cu_seqlens = torch.repeat_interleave(
|
cu_seqlens = np.repeat(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
|
||||||
grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
|
axis=0, dtype=np.int32
|
||||||
).cumsum(dim=0, dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32)
|
)
|
||||||
cu_seqlens = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens])
|
cu_seqlens = np.concatenate([np.zeros(1, dtype=np.int32), cu_seqlens])
|
||||||
|
cu_seqlens = torch.from_numpy(cu_seqlens)
|
||||||
|
|
||||||
hidden_states = hidden_states.unsqueeze(1)
|
hidden_states = hidden_states.unsqueeze(1)
|
||||||
max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
|
max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
|
||||||
|
|||||||
Reference in New Issue
Block a user