[custom_op][vllm-plugin] update custom_op class to use op_registry (#19164)
Signed-off-by: Chendi.Xue <chendi.xue@intel.com>
This commit is contained in:
@@ -1,6 +1,8 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.config import get_current_vllm_config
|
||||
@@ -16,6 +18,24 @@ class CustomOp(nn.Module):
|
||||
Dispatches the forward method to the appropriate backend.
|
||||
"""
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
try:
|
||||
op_name = cls.__name__
|
||||
except AttributeError:
|
||||
raise TypeError(
|
||||
f"Cannot instantiate '{cls.__name__}': its 'name' attribute "
|
||||
f"was not set, possibly because it was not decorated with "
|
||||
f"@CustomOp.register, or it's the CustomOp base class itself."
|
||||
) from None
|
||||
|
||||
if op_name not in cls.op_registry_oot:
|
||||
op_cls_to_instantiate = cls
|
||||
else:
|
||||
op_cls_to_instantiate = cls.op_registry_oot[op_name]
|
||||
logger.debug("Instantiating custom op: %s using %s", op_name,
|
||||
str(op_cls_to_instantiate))
|
||||
return super().__new__(op_cls_to_instantiate)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self._forward_method = self.dispatch_forward()
|
||||
@@ -138,6 +158,7 @@ class CustomOp(nn.Module):
|
||||
# - MyOp.enabled()
|
||||
# - op_registry["my_op"].enabled()
|
||||
op_registry: dict[str, type['CustomOp']] = {}
|
||||
op_registry_oot: dict[str, type['CustomOp']] = {}
|
||||
|
||||
# Decorator to register custom ops.
|
||||
@classmethod
|
||||
@@ -150,3 +171,38 @@ class CustomOp(nn.Module):
|
||||
return op_cls
|
||||
|
||||
return decorator
|
||||
|
||||
# Decorator to register out-of-tree(oot) custom ops.
|
||||
# For OOT custom ops:
|
||||
# if in-tree layer class is registered with an oot_custom_op layer,
|
||||
# the oot_custom_op layer will be used instead.
|
||||
# Example:
|
||||
# - @UnquantizedFusedMoEMethod.register_oot
|
||||
# class HPUUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod)
|
||||
# or
|
||||
# - @CustomOP.register_oot(name="UnquantizedFusedMoEMethod")
|
||||
@classmethod
|
||||
def register_oot(cls, _decorated_op_cls=None, name: Optional[str] = None):
|
||||
|
||||
def decorator(op_cls):
|
||||
reg_name = name if name is not None else cls.__name__
|
||||
assert reg_name not in cls.op_registry_oot, \
|
||||
f"Duplicate op name: {reg_name}"
|
||||
op_cls.name = reg_name
|
||||
cls.op_registry_oot[reg_name] = op_cls
|
||||
return op_cls
|
||||
|
||||
if _decorated_op_cls is None:
|
||||
# Called with parentheses: @CustomOP.register_oot()
|
||||
# or @CustomOP.register_oot(name="...")
|
||||
# So, _decorated_op_cls is None.
|
||||
# We return the actual decorator function.
|
||||
return decorator
|
||||
elif isinstance(_decorated_op_cls, type): # Check if it's a class
|
||||
# Called without parentheses: @CustomOP.register_oot
|
||||
# The first argument is the class itself.
|
||||
# We call the 'decorator' function immediately with the class.
|
||||
return decorator(_decorated_op_cls)
|
||||
else:
|
||||
# Handle other unexpected cases if necessary
|
||||
raise TypeError("Decorator can only be applied to classes.")
|
||||
|
||||
Reference in New Issue
Block a user