[PluggableLayer][1/N] Define PluggableLayer (#32331)
Signed-off-by: whx-sjtu <2952154980@qq.com>
This commit is contained in:
@@ -8,15 +8,6 @@ This document will introduce how CustomOp works in vLLM and how to implement a n
|
|||||||
|
|
||||||
`CustomOp` manages two dictionaries of all custom ops (i.e., op classes, indexed by registered name) in its class, for vLLM and OOT plugins respectively.
|
`CustomOp` manages two dictionaries of all custom ops (i.e., op classes, indexed by registered name) in its class, for vLLM and OOT plugins respectively.
|
||||||
|
|
||||||
??? code
|
|
||||||
|
|
||||||
```python
|
|
||||||
class CustomOp(nn.Module):
|
|
||||||
|
|
||||||
op_registry: dict[str, type["CustomOp"]] = {}
|
|
||||||
op_registry_oot: dict[str, type["CustomOp"]] = {}
|
|
||||||
```
|
|
||||||
|
|
||||||
We can use `@CustomOp.register("op_name")` to register an op class to the `CustomOp` system. After this, the `op_name` and its class will be added into the `op_registry` dictionary. In addition, we can also register an OOT op by `@CustomOp.register_oot("op_name")`. We will introduce this mechanism in detail later.
|
We can use `@CustomOp.register("op_name")` to register an op class to the `CustomOp` system. After this, the `op_name` and its class will be added into the `op_registry` dictionary. In addition, we can also register an OOT op by `@CustomOp.register_oot("op_name")`. We will introduce this mechanism in detail later.
|
||||||
|
|
||||||
When a `CustomOp` is called (i.e., call its `forward()` method), if it is enabled (i.e., with `--compilation_config.custom_ops '["+op_name"]'`), it will automatically dispatch the forward method to the appropriate backend according to `current_platform`. Otherwise (i.e., it is disabled), it will only call the `forward_native()` method to use PyTorch-native implementation of this forward method.
|
When a `CustomOp` is called (i.e., call its `forward()` method), if it is enabled (i.e., with `--compilation_config.custom_ops '["+op_name"]'`), it will automatically dispatch the forward method to the appropriate backend according to `current_platform`. Otherwise (i.e., it is disabled), it will only call the `forward_native()` method to use PyTorch-native implementation of this forward method.
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ from vllm.config import (
|
|||||||
get_cached_compilation_config,
|
get_cached_compilation_config,
|
||||||
set_current_vllm_config,
|
set_current_vllm_config,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.custom_op import CustomOp
|
from vllm.model_executor.custom_op import CustomOp, op_registry
|
||||||
from vllm.model_executor.layers.activation import (
|
from vllm.model_executor.layers.activation import (
|
||||||
GeluAndMul,
|
GeluAndMul,
|
||||||
ReLUSquaredActivation,
|
ReLUSquaredActivation,
|
||||||
@@ -98,17 +98,17 @@ def test_enabled_ops(
|
|||||||
ops_enabled = [bool(x) for x in ops_enabled]
|
ops_enabled = [bool(x) for x in ops_enabled]
|
||||||
|
|
||||||
assert RMSNorm(1024).enabled() == ops_enabled[0]
|
assert RMSNorm(1024).enabled() == ops_enabled[0]
|
||||||
assert CustomOp.op_registry["rms_norm"].enabled() == ops_enabled[0]
|
assert op_registry["rms_norm"].enabled() == ops_enabled[0]
|
||||||
|
|
||||||
assert SiluAndMul().enabled() == ops_enabled[1]
|
assert SiluAndMul().enabled() == ops_enabled[1]
|
||||||
assert CustomOp.op_registry["silu_and_mul"].enabled() == ops_enabled[1]
|
assert op_registry["silu_and_mul"].enabled() == ops_enabled[1]
|
||||||
|
|
||||||
assert GeluAndMul().enabled() == ops_enabled[2]
|
assert GeluAndMul().enabled() == ops_enabled[2]
|
||||||
assert CustomOp.op_registry["gelu_and_mul"].enabled() == ops_enabled[2]
|
assert op_registry["gelu_and_mul"].enabled() == ops_enabled[2]
|
||||||
|
|
||||||
# If registered, subclasses should follow their own name
|
# If registered, subclasses should follow their own name
|
||||||
assert Relu3().enabled() == ops_enabled[3]
|
assert Relu3().enabled() == ops_enabled[3]
|
||||||
assert CustomOp.op_registry["relu3"].enabled() == ops_enabled[3]
|
assert op_registry["relu3"].enabled() == ops_enabled[3]
|
||||||
|
|
||||||
# Unregistered subclass
|
# Unregistered subclass
|
||||||
class SiluAndMul2(SiluAndMul):
|
class SiluAndMul2(SiluAndMul):
|
||||||
|
|||||||
@@ -11,6 +11,86 @@ from vllm.platforms import current_platform
|
|||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# Dictionary of all custom ops (classes, indexed by registered name).
|
||||||
|
# To check if an op with a name is enabled, call .enabled() on the class.
|
||||||
|
# Examples:
|
||||||
|
# - MyOp.enabled()
|
||||||
|
# - op_registry["my_op"].enabled()
|
||||||
|
op_registry: dict[str, type["CustomOp"] | type["PluggableLayer"]] = {}
|
||||||
|
op_registry_oot: dict[str, type["CustomOp"] | type["PluggableLayer"]] = {}
|
||||||
|
|
||||||
|
|
||||||
|
class PluggableLayer(nn.Module):
|
||||||
|
"""
|
||||||
|
Base class for pluggable layers.
|
||||||
|
|
||||||
|
A PluggableLayer is a *module-composing* abstraction: it may instantiate other
|
||||||
|
``torch.nn.Module`` objects as sub-layers, and its functionality depends on
|
||||||
|
these sub-layers following a generalized invocation sequence. Also, it is stateful
|
||||||
|
and may hold parameters or buffers.
|
||||||
|
|
||||||
|
Unlike :class:`CustomOp`, PluggableLayer does NOT provide per-platform
|
||||||
|
``forward_*`` dispatch. Instead, it supports out-of-tree (OOT) replacement
|
||||||
|
of the entire layer class at instantiation time, allowing customized
|
||||||
|
initialization and submodule composition.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __new__(cls, *args, **kwargs):
|
||||||
|
try:
|
||||||
|
layer_class_name = cls.__name__
|
||||||
|
except AttributeError:
|
||||||
|
raise TypeError(
|
||||||
|
f"Cannot instantiate '{cls.__name__}': its 'name' attribute "
|
||||||
|
f"was not set, possibly because it was not decorated with "
|
||||||
|
f"@PluggableLayer.register, or it's the PluggableLayer itself."
|
||||||
|
) from None
|
||||||
|
|
||||||
|
if layer_class_name not in op_registry_oot:
|
||||||
|
layer_cls_to_instantiate = cls
|
||||||
|
else:
|
||||||
|
layer_cls_to_instantiate = op_registry_oot[layer_class_name]
|
||||||
|
logger.debug(
|
||||||
|
"Instantiating pluggable layer: %s using %s",
|
||||||
|
layer_class_name,
|
||||||
|
str(layer_cls_to_instantiate),
|
||||||
|
)
|
||||||
|
return super().__new__(layer_cls_to_instantiate)
|
||||||
|
|
||||||
|
# Decorator to register pluggable layers.
|
||||||
|
@classmethod
|
||||||
|
def register(cls, name: str):
|
||||||
|
def decorator(op_cls):
|
||||||
|
assert name not in op_registry, f"Duplicate op name: {name}"
|
||||||
|
op_cls.name = name
|
||||||
|
op_registry[name] = op_cls
|
||||||
|
return op_cls
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
# Decorator to register out-of-tree(oot) pluggable layers.
|
||||||
|
# For OOT pluggable layers:
|
||||||
|
# if in-tree layer class is registered with an oot_custom_layer,
|
||||||
|
# the oot_custom_layer will be used instead.
|
||||||
|
@classmethod
|
||||||
|
def register_oot(cls, _decorated_layer_cls=None, name: str | None = None):
|
||||||
|
def decorator(layer_cls):
|
||||||
|
reg_name = name if name is not None else cls.__name__
|
||||||
|
assert reg_name not in op_registry_oot, f"Duplicate layer name: {reg_name}"
|
||||||
|
layer_cls.name = reg_name
|
||||||
|
op_registry_oot[reg_name] = layer_cls
|
||||||
|
return layer_cls
|
||||||
|
|
||||||
|
if _decorated_layer_cls is None:
|
||||||
|
# Called with parentheses: @PluggableLayer.register_oot()
|
||||||
|
# or @PluggableLayer.register_oot(name="...")
|
||||||
|
return decorator
|
||||||
|
elif isinstance(_decorated_layer_cls, type): # Check if it's a class
|
||||||
|
# Called without parentheses: @PluggableLayer.register_oot
|
||||||
|
return decorator(_decorated_layer_cls)
|
||||||
|
else:
|
||||||
|
raise TypeError("Decorator can only be applied to classes.")
|
||||||
|
|
||||||
|
|
||||||
class CustomOp(nn.Module):
|
class CustomOp(nn.Module):
|
||||||
"""
|
"""
|
||||||
Base class for custom ops.
|
Base class for custom ops.
|
||||||
@@ -27,10 +107,10 @@ class CustomOp(nn.Module):
|
|||||||
f"@CustomOp.register, or it's the CustomOp base class itself."
|
f"@CustomOp.register, or it's the CustomOp base class itself."
|
||||||
) from None
|
) from None
|
||||||
|
|
||||||
if op_name not in cls.op_registry_oot:
|
if op_name not in op_registry_oot:
|
||||||
op_cls_to_instantiate = cls
|
op_cls_to_instantiate = cls
|
||||||
else:
|
else:
|
||||||
op_cls_to_instantiate = cls.op_registry_oot[op_name]
|
op_cls_to_instantiate = op_registry_oot[op_name]
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"Instantiating custom op: %s using %s",
|
"Instantiating custom op: %s using %s",
|
||||||
op_name,
|
op_name,
|
||||||
@@ -150,21 +230,13 @@ class CustomOp(nn.Module):
|
|||||||
|
|
||||||
return not count_none > 0 or count_all > 0
|
return not count_none > 0 or count_all > 0
|
||||||
|
|
||||||
# Dictionary of all custom ops (classes, indexed by registered name).
|
|
||||||
# To check if an op with a name is enabled, call .enabled() on the class.
|
|
||||||
# Examples:
|
|
||||||
# - MyOp.enabled()
|
|
||||||
# - op_registry["my_op"].enabled()
|
|
||||||
op_registry: dict[str, type["CustomOp"]] = {}
|
|
||||||
op_registry_oot: dict[str, type["CustomOp"]] = {}
|
|
||||||
|
|
||||||
# Decorator to register custom ops.
|
# Decorator to register custom ops.
|
||||||
@classmethod
|
@classmethod
|
||||||
def register(cls, name: str):
|
def register(cls, name: str):
|
||||||
def decorator(op_cls):
|
def decorator(op_cls):
|
||||||
assert name not in cls.op_registry, f"Duplicate op name: {name}"
|
assert name not in op_registry, f"Duplicate op name: {name}"
|
||||||
op_cls.name = name
|
op_cls.name = name
|
||||||
cls.op_registry[name] = op_cls
|
op_registry[name] = op_cls
|
||||||
return op_cls
|
return op_cls
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
@@ -182,9 +254,9 @@ class CustomOp(nn.Module):
|
|||||||
def register_oot(cls, _decorated_op_cls=None, name: str | None = None):
|
def register_oot(cls, _decorated_op_cls=None, name: str | None = None):
|
||||||
def decorator(op_cls):
|
def decorator(op_cls):
|
||||||
reg_name = name if name is not None else cls.__name__
|
reg_name = name if name is not None else cls.__name__
|
||||||
assert reg_name not in cls.op_registry_oot, f"Duplicate op name: {reg_name}"
|
assert reg_name not in op_registry_oot, f"Duplicate op name: {reg_name}"
|
||||||
op_cls.name = reg_name
|
op_cls.name = reg_name
|
||||||
cls.op_registry_oot[reg_name] = op_cls
|
op_registry_oot[reg_name] = op_cls
|
||||||
return op_cls
|
return op_cls
|
||||||
|
|
||||||
if _decorated_op_cls is None:
|
if _decorated_op_cls is None:
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import torch
|
|||||||
|
|
||||||
from vllm.attention.layer import MLAAttention
|
from vllm.attention.layer import MLAAttention
|
||||||
from vllm.config import CacheConfig
|
from vllm.config import CacheConfig
|
||||||
from vllm.model_executor.custom_op import CustomOp
|
from vllm.model_executor.custom_op import PluggableLayer
|
||||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||||
|
|
||||||
|
|
||||||
@@ -30,13 +30,13 @@ class MLAModules:
|
|||||||
|
|
||||||
|
|
||||||
# --8<-- [start:multi_head_latent_attention]
|
# --8<-- [start:multi_head_latent_attention]
|
||||||
@CustomOp.register("multi_head_latent_attention")
|
@PluggableLayer.register("multi_head_latent_attention")
|
||||||
class MultiHeadLatentAttentionWrapper(CustomOp):
|
class MultiHeadLatentAttentionWrapper(PluggableLayer):
|
||||||
"""MLA layer registered as CustomOp to allow OOT backends to add
|
"""Pluggable MLA layer which allows OOT backends to add
|
||||||
custom implementations of the outer MLA layer (including rope & o_proj).
|
custom implementations of the outer MLA layer (including rope & o_proj).
|
||||||
Note that currently MLA ignores the enable/disable mechanism of CustomOp
|
Note that currently oot platforms can still use CustomOp.register_oot to
|
||||||
because there is only one in-tree implementation in forward_native.
|
replace MLA layer entirely, although we use PluggableLayer to register
|
||||||
TODO: implement this with a new PluggableLayer mechanism.
|
this layer now.
|
||||||
|
|
||||||
This class takes positions and hidden_states as input.
|
This class takes positions and hidden_states as input.
|
||||||
The input tensors can either contain prefill tokens or decode tokens.
|
The input tensors can either contain prefill tokens or decode tokens.
|
||||||
@@ -110,7 +110,7 @@ class MultiHeadLatentAttentionWrapper(CustomOp):
|
|||||||
|
|
||||||
self.prefix = prefix
|
self.prefix = prefix
|
||||||
|
|
||||||
def forward_native(
|
def forward(
|
||||||
self,
|
self,
|
||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
@@ -174,6 +174,3 @@ class MultiHeadLatentAttentionWrapper(CustomOp):
|
|||||||
)
|
)
|
||||||
|
|
||||||
return self.o_proj(attn_out)[0]
|
return self.o_proj(attn_out)[0]
|
||||||
|
|
||||||
def forward_cuda(self, *args, **kwargs):
|
|
||||||
return self.forward_native(*args, **kwargs)
|
|
||||||
|
|||||||
Reference in New Issue
Block a user