[CustomOp] Extract ApplyRotaryEmb as CustomOp and unify the dispatch logic (#29873)
Signed-off-by: shen-shanshan <467638484@qq.com> Co-authored-by: gcanlin <canlinguosdu@gmail.com> Co-authored-by: TJian <tunjian.tan@embeddedllm.com>
This commit is contained in:
@@ -60,6 +60,9 @@ from vllm.model_executor.layers.linear import (
|
||||
)
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.rotary_embedding.common import (
|
||||
ApplyRotaryEmb,
|
||||
)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.module_mapping import MultiModelKeys
|
||||
from vllm.model_executor.models.vision import should_torch_compile_mm_vit
|
||||
@@ -95,7 +98,6 @@ from .qwen2_vl import Qwen2VLDummyInputsBuilder as Qwen2_5_VLDummyInputsBuilder
|
||||
from .qwen2_vl import (
|
||||
Qwen2VLMultiModalProcessor,
|
||||
Qwen2VLProcessingInfo,
|
||||
apply_rotary_pos_emb_vision,
|
||||
)
|
||||
from .utils import (
|
||||
AutoWeightsLoader,
|
||||
@@ -353,6 +355,8 @@ class Qwen2_5_VisionAttention(nn.Module):
|
||||
multimodal_config=multimodal_config,
|
||||
)
|
||||
|
||||
self.apply_rotary_emb = ApplyRotaryEmb(enforce_enable=True)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
@@ -378,8 +382,10 @@ class Qwen2_5_VisionAttention(nn.Module):
|
||||
qk_reshaped = einops.rearrange(
|
||||
qk, "b s two head head_dim -> (two b) s head head_dim", two=2
|
||||
)
|
||||
qk_rotated = apply_rotary_pos_emb_vision(
|
||||
qk_reshaped, cos=rotary_pos_emb_cos, sin=rotary_pos_emb_sin
|
||||
qk_rotated = self.apply_rotary_emb(
|
||||
qk_reshaped,
|
||||
rotary_pos_emb_cos,
|
||||
rotary_pos_emb_sin,
|
||||
)
|
||||
qk_rotated = qk_rotated.view(
|
||||
2,
|
||||
|
||||
Reference in New Issue
Block a user