[CustomOp] Extract ApplyRotaryEmb as CustomOp and unify the dispatch logic (#29873)

Signed-off-by: shen-shanshan <467638484@qq.com>
Co-authored-by: gcanlin <canlinguosdu@gmail.com>
Co-authored-by: TJian <tunjian.tan@embeddedllm.com>
Shanshan Shen authored on 2025-12-16 11:08:16 +08:00; committed by GitHub
parent ff21a0fc85
commit 3bd9c49158
14 changed files with 553 additions and 280 deletions


@@ -6,7 +6,6 @@ within a vision language model."""
 from collections.abc import Iterable
 
 import torch
-from einops import rearrange, repeat
 from torch import nn
 from torch.nn import functional as F
 from transformers import Siglip2VisionConfig
@@ -26,6 +25,9 @@ from vllm.model_executor.layers.linear import (
     RowParallelLinear,
 )
 from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.layers.rotary_embedding.common import (
+    ApplyRotaryEmb,
+)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.platforms import current_platform
@@ -146,40 +148,6 @@ class Siglip2VisionEmbeddings(nn.Module):
         return patch_embeds
 
-
-# copy from flash_attn/layers/rotary.py
-def rotate_half(x, interleaved=False):
-    if not interleaved:
-        x1, x2 = x.chunk(2, dim=-1)
-        return torch.cat((-x2, x1), dim=-1)
-    else:
-        x1, x2 = x[..., ::2], x[..., 1::2]
-        return rearrange(
-            torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2
-        )
-
-
-def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
-    """
-    x: (batch_size, seqlen, nheads, headdim)
-    cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
-    """
-    ro_dim = cos.shape[-1] * 2
-    assert ro_dim <= x.shape[-1]
-    cos = repeat(
-        cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
-    )
-    sin = repeat(
-        sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
-    )
-    return torch.cat(
-        [
-            x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin,
-            x[..., ro_dim:],
-        ],
-        dim=-1,
-    )
-
-
 def apply_rotary_pos_emb(
     q: torch.Tensor,
     k: torch.Tensor,
@@ -189,14 +157,20 @@ def apply_rotary_pos_emb(
 ) -> tuple[torch.Tensor, torch.Tensor]:
     cos = cos.chunk(2, dim=-1)[0].contiguous()
     sin = sin.chunk(2, dim=-1)[0].contiguous()
-    if is_flash_attn_backend and current_platform.is_cuda():
-        from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
-
-        apply_rotary_emb_func = apply_rotary_emb
+    apply_rotary_emb = ApplyRotaryEmb(
+        enforce_enable=True,
+        enable_fp32_compute=True,
+    )
+    if is_flash_attn_backend and not current_platform.is_cuda():
+        apply_rotary_emb_func = apply_rotary_emb.forward_cuda
     else:
-        apply_rotary_emb_func = apply_rotary_emb_torch
-    q_embed = apply_rotary_emb_func(q.float(), cos.float(), sin.float()).type_as(q)
-    k_embed = apply_rotary_emb_func(k.float(), cos.float(), sin.float()).type_as(k)
+        apply_rotary_emb_func = apply_rotary_emb.forward_native
+    q_embed = apply_rotary_emb_func(q, cos, sin)
+    k_embed = apply_rotary_emb_func(k, cos, sin)
     return q_embed, k_embed
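
Note on the new dispatch path: the call site above builds an ApplyRotaryEmb op (imported from vllm.model_executor.layers.rotary_embedding.common earlier in this diff) and selects forward_cuda or forward_native explicitly. The sketch below is a minimal, standalone approximation of such an op, reconstructed only from what this diff shows: the enforce_enable / enable_fp32_compute constructor arguments, the two forward entry points, and the pure-torch rotary math removed above. It is not the actual vLLM implementation (which is a CustomOp subclass), and the fallback behaviour and helper names here are assumptions made for illustration.

# Minimal standalone sketch (not the real vLLM ApplyRotaryEmb CustomOp).
import torch


def _rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Non-interleaved rotation, same math as the removed rotate_half() above.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


class ApplyRotaryEmbSketch:
    """Illustrative stand-in for ApplyRotaryEmb (not the vLLM class)."""

    def __init__(self, enforce_enable: bool = False, enable_fp32_compute: bool = False):
        # enforce_enable: assumed to force-enable the op regardless of global
        # CustomOp settings. enable_fp32_compute: assumed to upcast inputs to
        # fp32 internally, replacing the manual .float()/.type_as() calls that
        # the old call site performed.
        self.enforce_enable = enforce_enable
        self.enable_fp32_compute = enable_fp32_compute

    def _maybe_upcast(self, *tensors: torch.Tensor) -> tuple[torch.Tensor, ...]:
        return tuple(t.float() for t in tensors) if self.enable_fp32_compute else tensors

    def forward_native(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
        # Pure-PyTorch path; mirrors the removed apply_rotary_emb_torch()
        # (non-interleaved case). cos/sin: (..., seqlen, rotary_dim / 2).
        dtype = x.dtype
        x, cos, sin = self._maybe_upcast(x, cos, sin)
        ro_dim = cos.shape[-1] * 2
        cos = torch.cat((cos, cos), dim=-1).unsqueeze(-2)  # broadcast over heads
        sin = torch.cat((sin, sin), dim=-1).unsqueeze(-2)
        rotated = x[..., :ro_dim] * cos + _rotate_half(x[..., :ro_dim]) * sin
        return torch.cat((rotated, x[..., ro_dim:]), dim=-1).to(dtype)

    def forward_cuda(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
        # CUDA path: defer to the fused kernel the old code imported directly;
        # falling back to the native path if it is unavailable is a sketch-only choice.
        try:
            from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
        except ImportError:
            return self.forward_native(x, cos, sin)
        dtype = x.dtype
        x, cos, sin = self._maybe_upcast(x, cos, sin)
        return apply_rotary_emb(x, cos, sin).to(dtype)

With this stand-in, the updated call site reads as in the diff: construct the op once with enforce_enable=True and enable_fp32_compute=True, pick forward_cuda or forward_native based on the attention backend and platform, and apply it to q and k directly, since the fp32 upcast and dtype restore now live inside the op rather than at every call site.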