[CustomOp] Extract ApplyRotaryEmb as CustomOp and unify the dispatch logic (#29873)
Signed-off-by: shen-shanshan <467638484@qq.com> Co-authored-by: gcanlin <canlinguosdu@gmail.com> Co-authored-by: TJian <tunjian.tan@embeddedllm.com>
This commit is contained in:
@@ -59,8 +59,7 @@ from vllm.model_executor.layers.linear import (
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.rotary_embedding.common import (
|
||||
apply_rotary_emb_torch,
|
||||
dispatch_rotary_emb_function,
|
||||
ApplyRotaryEmb,
|
||||
)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.module_mapping import MultiModelKeys
|
||||
@@ -280,16 +279,6 @@ class Qwen2VisionMLP(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
def apply_rotary_pos_emb_vision(
    t: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> torch.Tensor:
    """Apply neox-style rotary position embedding to the vision tensor ``t``.

    Args:
        t: Input tensor to rotate.
        cos: Precomputed cosine terms of the rotary embedding.
        sin: Precomputed sine terms of the rotary embedding.

    Returns:
        The rotated tensor, cast back to the dtype of ``t``.
    """
    # NOTE(review): dispatch_rotary_emb_function presumably selects a
    # platform-specific kernel when available, falling back to the pure-torch
    # implementation given here (neox interleaving) — confirm in
    # rotary_embedding.common.
    emb_fn = dispatch_rotary_emb_function(
        default=partial(apply_rotary_emb_torch, is_neox_style=True)
    )
    rotated = emb_fn(t, cos, sin)
    # Cast back in case the dispatched kernel upcast internally.
    return rotated.type_as(t)
|
||||
|
||||
|
||||
class Qwen2VisionAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -341,6 +330,8 @@ class Qwen2VisionAttention(nn.Module):
|
||||
multimodal_config=multimodal_config,
|
||||
)
|
||||
|
||||
self.apply_rotary_emb = ApplyRotaryEmb(enforce_enable=True)
|
||||
|
||||
def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
|
||||
# [s, b, 3 * head * head_dim]
|
||||
seq_len, bs, _ = qkv.shape
|
||||
@@ -387,8 +378,10 @@ class Qwen2VisionAttention(nn.Module):
|
||||
|
||||
# [2 * b, s, heads, head_dim]
|
||||
qk_concat = torch.cat([q, k], dim=0)
|
||||
qk_rotated = apply_rotary_pos_emb_vision(
|
||||
qk_concat, rotary_pos_emb_cos, rotary_pos_emb_sin
|
||||
qk_rotated = self.apply_rotary_emb(
|
||||
qk_concat,
|
||||
rotary_pos_emb_cos,
|
||||
rotary_pos_emb_sin,
|
||||
)
|
||||
q, k = torch.chunk(qk_rotated, 2, dim=0)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user