[PluggableLayer][1/N] Define PluggableLayer (Fix ci) (#32744)
Signed-off-by: whx-sjtu <2952154980@qq.com>
This commit is contained in:
@@ -11,7 +11,7 @@ from vllm.config import (
|
||||
get_cached_compilation_config,
|
||||
set_current_vllm_config,
|
||||
)
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
from vllm.model_executor.custom_op import CustomOp, op_registry
|
||||
from vllm.model_executor.layers.activation import (
|
||||
GeluAndMul,
|
||||
ReLUSquaredActivation,
|
||||
@@ -98,17 +98,17 @@ def test_enabled_ops(
|
||||
ops_enabled = [bool(x) for x in ops_enabled]
|
||||
|
||||
assert RMSNorm(1024).enabled() == ops_enabled[0]
|
||||
assert CustomOp.op_registry["rms_norm"].enabled() == ops_enabled[0]
|
||||
assert op_registry["rms_norm"].enabled() == ops_enabled[0]
|
||||
|
||||
assert SiluAndMul().enabled() == ops_enabled[1]
|
||||
assert CustomOp.op_registry["silu_and_mul"].enabled() == ops_enabled[1]
|
||||
assert op_registry["silu_and_mul"].enabled() == ops_enabled[1]
|
||||
|
||||
assert GeluAndMul().enabled() == ops_enabled[2]
|
||||
assert CustomOp.op_registry["gelu_and_mul"].enabled() == ops_enabled[2]
|
||||
assert op_registry["gelu_and_mul"].enabled() == ops_enabled[2]
|
||||
|
||||
# If registered, subclasses should follow their own name
|
||||
assert Relu3().enabled() == ops_enabled[3]
|
||||
assert CustomOp.op_registry["relu3"].enabled() == ops_enabled[3]
|
||||
assert op_registry["relu3"].enabled() == ops_enabled[3]
|
||||
|
||||
# Unregistered subclass
|
||||
class SiluAndMul2(SiluAndMul):
|
||||
|
||||
Reference in New Issue
Block a user