[custom_op][vllm-plugin] update custom_op class to use op_registry (#19164)
Signed-off-by: Chendi.Xue <chendi.xue@intel.com>
This commit is contained in:
@@ -10,5 +10,7 @@ setup(
|
||||
entry_points={
|
||||
'vllm.platform_plugins': [
|
||||
"dummy_platform_plugin = vllm_add_dummy_platform:dummy_platform_plugin" # noqa
|
||||
]
|
||||
],
|
||||
"vllm.general_plugins":
|
||||
["dummy_custom_ops = vllm_add_dummy_platform:register_ops"],
|
||||
})
|
||||
|
||||
@@ -6,3 +6,7 @@ from typing import Optional
|
||||
|
||||
def dummy_platform_plugin() -> Optional[str]:
|
||||
return "vllm_add_dummy_platform.dummy_platform.DummyPlatform"
|
||||
|
||||
|
||||
def register_ops():
|
||||
import vllm_add_dummy_platform.dummy_custom_ops # noqa
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from vllm.attention.backends.flash_attn import FlashAttentionBackend
|
||||
from vllm.attention.backends.placeholder_attn import (
|
||||
PlaceholderAttentionBackend)
|
||||
|
||||
|
||||
class DummyAttentionBackend(FlashAttentionBackend):
|
||||
class DummyAttentionBackend(PlaceholderAttentionBackend):
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
|
||||
@@ -0,0 +1,20 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
|
||||
|
||||
|
||||
# Register CustomRotaryEmbedding to CustomOP.
|
||||
@RotaryEmbedding.register_oot
|
||||
class DummyRotaryEmbedding(RotaryEmbedding):
|
||||
"""Original rotary positional embedding."""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.addition_config = True
|
||||
|
||||
def forward_oot(self, *args,
|
||||
**kwargs) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
return super().forward_oot(*args, **kwargs)
|
||||
@@ -1,12 +1,29 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from vllm.platforms.cuda import CudaPlatform
|
||||
from vllm.platforms.interface import Platform, PlatformEnum
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import VllmConfig
|
||||
else:
|
||||
VllmConfig = None
|
||||
from vllm import envs
|
||||
|
||||
|
||||
class DummyPlatform(CudaPlatform):
|
||||
class DummyPlatform(Platform):
|
||||
_enum = PlatformEnum.OOT
|
||||
device_name = "DummyDevice"
|
||||
device_type: str = "privateuseone"
|
||||
dispatch_key: str = "PrivateUse1"
|
||||
|
||||
@classmethod
|
||||
def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
|
||||
if envs.VLLM_USE_V1:
|
||||
compilation_config = vllm_config.compilation_config
|
||||
# Activate custom ops for v1.
|
||||
compilation_config.custom_ops = ["all"]
|
||||
|
||||
def get_attn_backend_cls(self, backend_name, head_size, dtype,
|
||||
kv_cache_dtype, block_size, use_v1, use_mla):
|
||||
return "vllm_add_dummy_platform.dummy_attention_backend.DummyAttentionBackend" # noqa E501
|
||||
return "vllm_add_dummy_platform.dummy_attention_backend.DummyAttentionBackend" # noqa E501
|
||||
Reference in New Issue
Block a user