[CustomOp] Extract ApplyRotaryEmb as CustomOp and unify the dispatch logic (#29873)
Signed-off-by: shen-shanshan <467638484@qq.com>
Co-authored-by: gcanlin <canlinguosdu@gmail.com>
Co-authored-by: TJian <tunjian.tan@embeddedllm.com>
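For context on the title: a CustomOp-style op keeps one pure-PyTorch path (forward_native) and one accelerated path (forward_cuda) behind a single object, so a model picks its backend in one place instead of re-implementing the flash-attn-vs-torch branch per file. The sketch below is illustrative only, not vLLM's actual class; apart from the forward_native/forward_cuda names, the (x, cos, sin) call shape, and the enforce_enable/enable_fp32_compute flags visible in the diff, every name is an assumption.

# Illustrative sketch of a CustomOp-style rotary-embedding op with unified
# dispatch. NOT the vLLM implementation.
import torch


def _rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Non-interleaved rotation: split the last dim in half, swap with a sign flip.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


class ApplyRotaryEmbSketch:
    def __init__(self, enforce_enable: bool = False, enable_fp32_compute: bool = False):
        # enforce_enable only mirrors the constructor seen in the diff;
        # its real semantics are not modeled in this sketch.
        self.enforce_enable = enforce_enable
        self.enable_fp32_compute = enable_fp32_compute

    def forward_native(self, x, cos, sin):
        # Pure-PyTorch path. x: (batch, seqlen, nheads, headdim);
        # cos/sin: (seqlen, rotary_dim / 2), optionally computed in fp32.
        dtype = x.dtype
        if self.enable_fp32_compute:
            x, cos, sin = x.float(), cos.float(), sin.float()
        cos = torch.cat((cos, cos), dim=-1).unsqueeze(-2)  # -> (seqlen, 1, rotary_dim)
        sin = torch.cat((sin, sin), dim=-1).unsqueeze(-2)
        ro_dim = cos.shape[-1]
        x_rot, x_pass = x[..., :ro_dim], x[..., ro_dim:]
        x_rot = x_rot * cos + _rotate_half(x_rot) * sin
        return torch.cat((x_rot, x_pass), dim=-1).to(dtype)

    def forward_cuda(self, x, cos, sin):
        # A real CustomOp would call an optimized kernel here (e.g. the
        # flash-attn rotary kernel); the sketch just reuses the native path.
        return self.forward_native(x, cos, sin)

    def __call__(self, x, cos, sin):
        # Unified dispatch: the backend is chosen once, here, rather than at
        # every vision-model call site.
        fn = self.forward_cuda if torch.cuda.is_available() else self.forward_native
        return fn(x, cos, sin)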
@@ -6,7 +6,6 @@ within a vision language model."""
from collections.abc import Iterable

import torch
from einops import rearrange, repeat
from torch import nn
from torch.nn import functional as F
from transformers import Siglip2VisionConfig
@@ -26,6 +25,9 @@ from vllm.model_executor.layers.linear import (
    RowParallelLinear,
)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding.common import (
    ApplyRotaryEmb,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.platforms import current_platform
@@ -146,40 +148,6 @@ class Siglip2VisionEmbeddings(nn.Module):
        return patch_embeds


# copy from flash_attn/layers/rotary.py
def rotate_half(x, interleaved=False):
    if not interleaved:
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)
    else:
        x1, x2 = x[..., ::2], x[..., 1::2]
        return rearrange(
            torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2
        )


def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
    """
    x: (batch_size, seqlen, nheads, headdim)
    cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
    """
    ro_dim = cos.shape[-1] * 2
    assert ro_dim <= x.shape[-1]
    cos = repeat(
        cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
    )
    sin = repeat(
        sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
    )
    return torch.cat(
        [
            x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin,
            x[..., ro_dim:],
        ],
        dim=-1,
    )


def apply_rotary_pos_emb(
    q: torch.Tensor,
    k: torch.Tensor,
@@ -189,14 +157,20 @@ def apply_rotary_pos_emb(
) -> tuple[torch.Tensor, torch.Tensor]:
    cos = cos.chunk(2, dim=-1)[0].contiguous()
    sin = sin.chunk(2, dim=-1)[0].contiguous()
    if is_flash_attn_backend and current_platform.is_cuda():
        from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb

        apply_rotary_emb_func = apply_rotary_emb
    apply_rotary_emb = ApplyRotaryEmb(
        enforce_enable=True,
        enable_fp32_compute=True,
    )

    if is_flash_attn_backend and not current_platform.is_cuda():
        apply_rotary_emb_func = apply_rotary_emb.forward_cuda
    else:
        apply_rotary_emb_func = apply_rotary_emb_torch
    q_embed = apply_rotary_emb_func(q.float(), cos.float(), sin.float()).type_as(q)
    k_embed = apply_rotary_emb_func(k.float(), cos.float(), sin.float()).type_as(k)
        apply_rotary_emb_func = apply_rotary_emb.forward_native

    q_embed = apply_rotary_emb_func(q, cos, sin)
    k_embed = apply_rotary_emb_func(k, cos, sin)

    return q_embed, k_embed
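One way to sanity-check the refactor is to compare the new op's forward_native path against the deleted pure-PyTorch helper on the same inputs. The sketch below is a hedged check, not part of the commit: it assumes a vLLM build that already contains this change (so ApplyRotaryEmb is importable from vllm.model_executor.layers.rotary_embedding.common) and re-inlines the removed apply_rotary_emb_torch (non-interleaved branch) from the hunk above as the reference.

# Hedged equivalence check (not part of the commit): the refactor is expected
# to preserve numerics, so forward_native should match the deleted reference.
import torch
from einops import repeat

from vllm.model_executor.layers.rotary_embedding.common import ApplyRotaryEmb


def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_emb_torch(x, cos, sin):
    # Copied from the deleted helper above, non-interleaved branch only.
    ro_dim = cos.shape[-1] * 2
    assert ro_dim <= x.shape[-1]
    cos = repeat(cos, "... d -> ... 1 (2 d)")
    sin = repeat(sin, "... d -> ... 1 (2 d)")
    return torch.cat(
        [x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim]) * sin, x[..., ro_dim:]],
        dim=-1,
    )


if __name__ == "__main__":
    batch, seqlen, nheads, headdim, rotary_dim = 2, 16, 4, 64, 64
    q = torch.randn(batch, seqlen, nheads, headdim)
    inv_freq = 1.0 / (10000.0 ** (torch.arange(0, rotary_dim, 2).float() / rotary_dim))
    freqs = torch.outer(torch.arange(seqlen).float(), inv_freq)  # (seqlen, rotary_dim / 2)
    cos, sin = freqs.cos(), freqs.sin()

    op = ApplyRotaryEmb(enforce_enable=True, enable_fp32_compute=True)
    ref = apply_rotary_emb_torch(q.float(), cos.float(), sin.float()).type_as(q)
    out = op.forward_native(q, cos, sin)
    torch.testing.assert_close(out, ref, rtol=1e-4, atol=1e-4)
    print("ApplyRotaryEmb.forward_native matches the removed torch reference")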