[CustomOp] Extract ApplyRotaryEmb as CustomOp and unify the dispatch logic (#29873)

Signed-off-by: shen-shanshan <467638484@qq.com>
Co-authored-by: gcanlin <canlinguosdu@gmail.com>
Co-authored-by: TJian <tunjian.tan@embeddedllm.com>
This commit is contained in:
Shanshan Shen
2025-12-16 11:08:16 +08:00
committed by GitHub
parent ff21a0fc85
commit 3bd9c49158
14 changed files with 553 additions and 280 deletions

View File

@@ -65,6 +65,9 @@ from vllm.model_executor.layers.linear import (
)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.rotary_embedding.common import (
ApplyRotaryEmb,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.multimodal import MULTIMODAL_REGISTRY
@@ -95,7 +98,7 @@ from .interfaces import (
SupportsMultiModal,
SupportsPP,
)
from .qwen2_vl import _create_qwen2vl_field_factory, apply_rotary_pos_emb_vision
from .qwen2_vl import _create_qwen2vl_field_factory
from .utils import (
AutoWeightsLoader,
WeightsMapper,
@@ -304,6 +307,8 @@ class Glm4vVisionAttention(nn.Module):
multimodal_config=multimodal_config,
)
self.apply_rotary_emb = ApplyRotaryEmb(enforce_enable=True)
def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
# [s, b, 3 * head * head_dim]
seq_len, bs, _ = qkv.shape
@@ -339,8 +344,10 @@ class Glm4vVisionAttention(nn.Module):
if rotary_pos_emb_cos is not None and rotary_pos_emb_sin is not None:
# [2 * b, s, heads, head_dim]
qk_concat = torch.cat([q, k], dim=0)
qk_rotated = apply_rotary_pos_emb_vision(
qk_concat, rotary_pos_emb_cos, rotary_pos_emb_sin
qk_rotated = self.apply_rotary_emb(
qk_concat,
rotary_pos_emb_cos,
rotary_pos_emb_sin,
)
q, k = torch.chunk(qk_rotated, 2, dim=0)