[Model][Perf] Use cos and sin cache in QwenVL (#28798)
Signed-off-by: gcanlin <canlinguosdu@gmail.com>
This commit is contained in:
@@ -65,6 +65,7 @@ from vllm.model_executor.layers.linear import (
|
||||
RowParallelLinear,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.module_mapping import MultiModelKeys
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
@@ -341,7 +342,8 @@ class Glm4vVisionAttention(nn.Module):
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
cu_seqlens: torch.Tensor,
|
||||
rotary_pos_emb: torch.Tensor,
|
||||
rotary_pos_emb_cos: torch.Tensor,
|
||||
rotary_pos_emb_sin: torch.Tensor,
|
||||
max_seqlen: int | None = None, # Only used for Flash Attention
|
||||
seqlens: list[int] | None = None, # Only used for xFormers
|
||||
) -> torch.Tensor:
|
||||
@@ -353,10 +355,12 @@ class Glm4vVisionAttention(nn.Module):
|
||||
batch_size = q.shape[1]
|
||||
|
||||
q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v))
|
||||
if rotary_pos_emb is not None:
|
||||
if rotary_pos_emb_cos is not None and rotary_pos_emb_sin is not None:
|
||||
# [2 * b, s, heads, head_dim]
|
||||
qk_concat = torch.cat([q, k], dim=0)
|
||||
qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb)
|
||||
qk_rotated = apply_rotary_pos_emb_vision(
|
||||
qk_concat, rotary_pos_emb_cos, rotary_pos_emb_sin
|
||||
)
|
||||
q, k = torch.chunk(qk_rotated, 2, dim=0)
|
||||
|
||||
if self.is_flash_attn_backend:
|
||||
@@ -454,14 +458,16 @@ class Glm4vVisionBlock(nn.Module):
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
cu_seqlens: torch.Tensor,
|
||||
rotary_pos_emb: torch.Tensor,
|
||||
rotary_pos_emb_cos: torch.Tensor,
|
||||
rotary_pos_emb_sin: torch.Tensor,
|
||||
max_seqlen: int | None = None, # Only used for Flash Attention
|
||||
seqlens: list[int] | None = None, # Only used for xFormers
|
||||
) -> torch.Tensor:
|
||||
x_attn = self.attn(
|
||||
self.norm1(x),
|
||||
cu_seqlens=cu_seqlens,
|
||||
rotary_pos_emb=rotary_pos_emb,
|
||||
rotary_pos_emb_cos=rotary_pos_emb_cos,
|
||||
rotary_pos_emb_sin=rotary_pos_emb_sin,
|
||||
max_seqlen=max_seqlen,
|
||||
seqlens=seqlens,
|
||||
)
|
||||
@@ -660,44 +666,6 @@ class Glm4vVisionEmbeddings(nn.Module):
|
||||
return embeddings
|
||||
|
||||
|
||||
class Glm4vVisionRotaryEmbedding(nn.Module):
|
||||
def __init__(self, dim: int, theta: float = 10000.0) -> None:
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.theta = theta
|
||||
inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
self._seq_len_cached = 0
|
||||
self._freqs_cached = None
|
||||
|
||||
def update_freqs_cache(self, seqlen: int) -> None:
|
||||
if seqlen > self._seq_len_cached:
|
||||
seqlen *= 2
|
||||
self._seq_len_cached = seqlen
|
||||
self.inv_freq = 1.0 / (
|
||||
self.theta
|
||||
** (
|
||||
torch.arange(
|
||||
0,
|
||||
self.dim,
|
||||
2,
|
||||
dtype=torch.float,
|
||||
device=self.inv_freq.device,
|
||||
)
|
||||
/ self.dim
|
||||
)
|
||||
)
|
||||
seq = torch.arange(
|
||||
seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype
|
||||
)
|
||||
freqs = torch.outer(seq, self.inv_freq)
|
||||
self._freqs_cached = freqs
|
||||
|
||||
def forward(self, seqlen: int) -> torch.Tensor:
|
||||
self.update_freqs_cache(seqlen)
|
||||
return self._freqs_cached[:seqlen]
|
||||
|
||||
|
||||
class Glm4vVisionTransformer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -731,7 +699,13 @@ class Glm4vVisionTransformer(nn.Module):
|
||||
|
||||
norm_layer = partial(RMSNorm, eps=norm_eps)
|
||||
head_dim = self.hidden_size // self.num_heads
|
||||
self.rotary_pos_emb = Glm4vVisionRotaryEmbedding(head_dim // 2)
|
||||
self.rotary_pos_emb = get_rope(
|
||||
head_size=head_dim,
|
||||
rotary_dim=head_dim // 2,
|
||||
max_position=8192,
|
||||
base=10000.0,
|
||||
is_neox_style=True,
|
||||
)
|
||||
self.blocks = nn.ModuleList(
|
||||
[
|
||||
Glm4vVisionBlock(
|
||||
@@ -789,7 +763,9 @@ class Glm4vVisionTransformer(nn.Module):
|
||||
def device(self) -> torch.device:
|
||||
return self.patch_embed.proj.weight.device
|
||||
|
||||
def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
|
||||
def rot_pos_emb(
|
||||
self, grid_thw: torch.Tensor
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
pos_ids = []
|
||||
for t, h, w in grid_thw:
|
||||
hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
|
||||
@@ -817,9 +793,18 @@ class Glm4vVisionTransformer(nn.Module):
|
||||
pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
|
||||
pos_ids = torch.cat(pos_ids, dim=0)
|
||||
max_grid_size = grid_thw[:, 1:].max()
|
||||
rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
|
||||
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
|
||||
return rotary_pos_emb, pos_ids
|
||||
|
||||
# Use pre-computed cos_sin_cache from RotaryEmbedding
|
||||
cos, sin = self.rotary_pos_emb.get_cos_sin(max_grid_size)
|
||||
|
||||
cos_h = cos[pos_ids[:, 0]] # (num_tokens, rotary_dim // 2)
|
||||
cos_w = cos[pos_ids[:, 1]]
|
||||
sin_h = sin[pos_ids[:, 0]]
|
||||
sin_w = sin[pos_ids[:, 1]]
|
||||
|
||||
cos_combined = torch.cat([cos_h, cos_w], dim=-1)
|
||||
sin_combined = torch.cat([sin_h, sin_w], dim=-1)
|
||||
return cos_combined, sin_combined, pos_ids
|
||||
|
||||
def compute_attn_mask_seqlen(
|
||||
self,
|
||||
@@ -848,7 +833,9 @@ class Glm4vVisionTransformer(nn.Module):
|
||||
x = self.post_conv_layernorm(x)
|
||||
|
||||
# compute position embedding
|
||||
rotary_pos_emb, image_type_ids = self.rot_pos_emb(grid_thw)
|
||||
rotary_pos_emb_cos, rotary_pos_emb_sin, image_type_ids = self.rot_pos_emb(
|
||||
grid_thw
|
||||
)
|
||||
# compute cu_seqlens
|
||||
cu_seqlens = torch.repeat_interleave(
|
||||
grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
|
||||
@@ -867,7 +854,8 @@ class Glm4vVisionTransformer(nn.Module):
|
||||
x = blk(
|
||||
x,
|
||||
cu_seqlens=cu_seqlens,
|
||||
rotary_pos_emb=rotary_pos_emb,
|
||||
rotary_pos_emb_cos=rotary_pos_emb_cos,
|
||||
rotary_pos_emb_sin=rotary_pos_emb_sin,
|
||||
max_seqlen=max_seqlen,
|
||||
seqlens=seqlens,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user