[CustomOp] Extract ApplyRotaryEmb as CustomOp and unify the dispatch logic (#29873)

Signed-off-by: shen-shanshan <467638484@qq.com>
Co-authored-by: gcanlin <canlinguosdu@gmail.com>
Co-authored-by: TJian <tunjian.tan@embeddedllm.com>
Authored by Shanshan Shen on 2025-12-16 11:08:16 +08:00; committed by GitHub
parent ff21a0fc85
commit 3bd9c49158
14 changed files with 553 additions and 280 deletions


@@ -22,7 +22,7 @@ from typing import Annotated, Literal
 import numpy as np
 import torch
 import torch.nn as nn
-from einops import rearrange, repeat
+from einops import rearrange
 from transformers import BatchFeature, PretrainedConfig
 from transformers.activations import GELUActivation
 from transformers.modeling_outputs import (
@@ -47,7 +47,7 @@ from vllm.model_executor.layers.linear import (
 )
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding.common import (
-    dispatch_rotary_emb_function,
+    ApplyRotaryEmb,
 )
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader,
@@ -130,47 +130,6 @@ def smart_resize(
     return h_bar, w_bar
-def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor:
-    if not interleaved:
-        x1, x2 = x.chunk(2, dim=-1)
-        return torch.cat((-x2, x1), dim=-1)
-    x1, x2 = x[..., ::2], x[..., 1::2]
-    return rearrange(torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2)
-def apply_rotary_emb_torch(
-    x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, interleaved: bool = False
-) -> torch.Tensor:
-    """
-    x: (batch_size, seqlen, nheads, headdim)
-    cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
-    """
-    ro_dim = cos.shape[-1] * 2
-    assert ro_dim <= x.shape[-1]
-    cos = repeat(
-        cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
-    )
-    sin = repeat(
-        sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
-    )
-    return torch.cat(
-        [
-            x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin,
-            x[..., ro_dim:],
-        ],
-        dim=-1,
-    )
-def apply_rotary_pos_emb_vision(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
-    rotary_emb_function = dispatch_rotary_emb_function(default=apply_rotary_emb_torch)
-    t_ = t.float()
-    cos = freqs.cos()
-    sin = freqs.sin()
-    output = rotary_emb_function(t_, cos, sin).type_as(t)
-    return output
 class PaddleOCRVLProcessingInfo(BaseProcessingInfo):
     def get_hf_config(self):
         return self.ctx.get_hf_config()
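Editor's note: the three module-level helpers removed above are now centralized behind the ApplyRotaryEmb CustomOp imported earlier in this diff. For intuition, a minimal torch-only sketch of the same non-interleaved rotation follows; it mirrors the removed apply_rotary_emb_torch and is only a reference, since the actual op may dispatch to an optimized kernel at runtime.

import torch

def rotate_half_ref(x: torch.Tensor) -> torch.Tensor:
    # Non-interleaved layout: split the rotary dims in half and swap them with a sign flip.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_emb_ref(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # x: (seqlen, nheads, headdim); cos/sin: (seqlen, rotary_dim // 2), as in the removed helper.
    ro_dim = cos.shape[-1] * 2
    cos = torch.cat((cos, cos), dim=-1).unsqueeze(-2)  # -> (seqlen, 1, rotary_dim), broadcast over heads
    sin = torch.cat((sin, sin), dim=-1).unsqueeze(-2)
    x_rot, x_pass = x[..., :ro_dim], x[..., ro_dim:]
    # Rotate only the first rotary_dim channels; pass the rest through unchanged.
    return torch.cat((x_rot * cos + rotate_half_ref(x_rot) * sin, x_pass), dim=-1)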
@@ -609,6 +568,10 @@ class SiglipAttention(nn.Module):
             multimodal_config=multimodal_config,
             prefix=f"{prefix}.attn",
         )
+        self.apply_rotary_emb = ApplyRotaryEmb(
+            enforce_enable=True,
+            enable_fp32_compute=True,
+        )

     def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
         seq_len, bs, _ = qkv.shape
@@ -651,7 +614,11 @@ class SiglipAttention(nn.Module):
         if rotary_pos_emb is not None:
             qk_concat = torch.cat([q, k], dim=0)
-            qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb)
+            qk_rotated = self.apply_rotary_emb(
+                qk_concat,
+                rotary_pos_emb.cos(),
+                rotary_pos_emb.sin(),
+            )
             q, k = torch.chunk(qk_rotated, 2, dim=0)

         context_layer = self.attn(
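Editor's note: taken together with the __init__ change above, the new call pattern reduces to roughly the sketch below. The class and method names are illustrative; only ApplyRotaryEmb, its constructor arguments, and its (x, cos, sin) call signature come from this diff.

import torch
import torch.nn as nn

from vllm.model_executor.layers.rotary_embedding.common import ApplyRotaryEmb

class VisionAttentionSketch(nn.Module):
    """Illustrative stand-in for SiglipAttention's rotary handling."""

    def __init__(self) -> None:
        super().__init__()
        # Constructed once per layer; backend dispatch (and, presumably, the fp32
        # upcast that the removed helper did via t.float()) happens inside the op.
        self.apply_rotary_emb = ApplyRotaryEmb(
            enforce_enable=True,
            enable_fp32_compute=True,
        )

    def rotate(self, q: torch.Tensor, k: torch.Tensor, rotary_pos_emb: torch.Tensor):
        # Rotate q and k in a single call by concatenating them, as the diff above does.
        qk_concat = torch.cat([q, k], dim=0)
        qk_rotated = self.apply_rotary_emb(
            qk_concat,
            rotary_pos_emb.cos(),
            rotary_pos_emb.sin(),
        )
        return torch.chunk(qk_rotated, 2, dim=0)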