[custom_op][vllm-plugin] update custom_op class to use op_registry (#19164)

Signed-off-by: Chendi.Xue <chendi.xue@intel.com>
Author: Chendi.Xue
Date: 2025-06-20 09:44:56 -05:00
Committed by: GitHub
Parent: f1e840e842
Commit: 7e8977fcd4
7 changed files with 120 additions and 6 deletions

setup.py
@@ -10,5 +10,7 @@ setup(
     entry_points={
         'vllm.platform_plugins': [
             "dummy_platform_plugin = vllm_add_dummy_platform:dummy_platform_plugin"  # noqa
-        ]
+        ],
+        "vllm.general_plugins":
+        ["dummy_custom_ops = vllm_add_dummy_platform:register_ops"],
     })
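The new "vllm.general_plugins" entry-point group is how the out-of-tree ops get wired in: each entry maps a name to a callable that is invoked during startup, before model construction. A minimal sketch of that discovery step, assuming Python 3.10+ importlib.metadata; load_general_plugins here is an illustrative name, not vLLM's actual loader:

from importlib.metadata import entry_points


def load_general_plugins() -> None:
    # Each entry point maps a name to a callable exported by the plugin
    # package; ep.load() imports the module and returns that callable.
    for ep in entry_points(group="vllm.general_plugins"):
        register_fn = ep.load()
        register_fn()  # e.g. runs register_ops() from vllm_add_dummy_platform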

vllm_add_dummy_platform/__init__.py
@@ -6,3 +6,7 @@ from typing import Optional
 
 def dummy_platform_plugin() -> Optional[str]:
     return "vllm_add_dummy_platform.dummy_platform.DummyPlatform"
+
+
+def register_ops():
+    import vllm_add_dummy_platform.dummy_custom_ops  # noqa
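Registration happens purely as an import side effect: register_ops only imports dummy_custom_ops, whose class decorator fires at import time. Assuming the plugin package is installed (e.g. pip install -e .), the entry point can be exercised by hand (Python 3.10+):

from importlib.metadata import entry_points

# Select the single entry point this setup.py declares; raises if the
# vllm_add_dummy_platform package is not installed.
(ep,) = entry_points(group="vllm.general_plugins", name="dummy_custom_ops")
register_ops = ep.load()
register_ops()  # imports dummy_custom_ops, firing its register_oot decorator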

vllm_add_dummy_platform/dummy_attention_backend.py
@@ -1,10 +1,11 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from vllm.attention.backends.flash_attn import FlashAttentionBackend
+from vllm.attention.backends.placeholder_attn import (
+    PlaceholderAttentionBackend)
 
 
-class DummyAttentionBackend(FlashAttentionBackend):
+class DummyAttentionBackend(PlaceholderAttentionBackend):
 
     @staticmethod
     def get_name() -> str:

vllm_add_dummy_platform/dummy_custom_ops.py
@@ -0,0 +1,20 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import torch
+
+from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
+
+
+# Register CustomRotaryEmbedding to CustomOP.
+@RotaryEmbedding.register_oot
+class DummyRotaryEmbedding(RotaryEmbedding):
+    """Original rotary positional embedding."""
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.addition_config = True
+
+    def forward_oot(self, *args,
+                    **kwargs) -> tuple[torch.Tensor, torch.Tensor]:
+        return super().forward_oot(*args, **kwargs)
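register_oot is inherited from vLLM's CustomOp base class (which RotaryEmbedding extends): used as a decorator, it records DummyRotaryEmbedding in the op registry as the out-of-tree replacement for RotaryEmbedding, and forward_oot is the dispatch hook used on OOT platforms. A rough illustration of the registry pattern, not vLLM's actual implementation:

from typing import Type


class CustomOpSketch:
    # Maps an in-tree op class to its registered out-of-tree override.
    _oot_registry: dict[Type, Type] = {}

    @classmethod
    def register_oot(cls, oot_cls: Type) -> Type:
        # Decorator form: @SomeOp.register_oot stores oot_cls under SomeOp.
        CustomOpSketch._oot_registry[cls] = oot_cls
        return oot_cls

    @classmethod
    def resolve(cls) -> Type:
        # At layer-construction time, prefer the registered override.
        return CustomOpSketch._oot_registry.get(cls, cls)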

vllm_add_dummy_platform/dummy_platform.py
@@ -1,12 +1,29 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from typing import TYPE_CHECKING
+
-from vllm.platforms.cuda import CudaPlatform
+from vllm.platforms.interface import Platform, PlatformEnum
+
+if TYPE_CHECKING:
+    from vllm.config import VllmConfig
+else:
+    VllmConfig = None
+
+from vllm import envs
 
 
-class DummyPlatform(CudaPlatform):
+class DummyPlatform(Platform):
+    _enum = PlatformEnum.OOT
     device_name = "DummyDevice"
+    device_type: str = "privateuseone"
+    dispatch_key: str = "PrivateUse1"
+
+    @classmethod
+    def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
+        if envs.VLLM_USE_V1:
+            compilation_config = vllm_config.compilation_config
+            # Activate custom ops for v1.
+            compilation_config.custom_ops = ["all"]
 
     def get_attn_backend_cls(self, backend_name, head_size, dtype,
                              kv_cache_dtype, block_size, use_v1, use_mla):
         return "vllm_add_dummy_platform.dummy_attention_backend.DummyAttentionBackend"  # noqa E501