diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index f5aec80d3..bb44e08a1 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -17,7 +17,7 @@ from vllm.distributed import ( tensor_model_parallel_all_reduce, ) from vllm.logger import init_logger -from vllm.model_executor.custom_op import CustomOp +from vllm.model_executor.custom_op import PluggableLayer from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase, @@ -239,7 +239,7 @@ class UnquantizedLinearMethod(LinearMethodBase): return dispatch_unquantized_gemm()(layer, x, layer.weight, bias) -class LinearBase(CustomOp): +class LinearBase(PluggableLayer): """Base linear layer. Args: @@ -294,7 +294,7 @@ class LinearBase(CustomOp): # --8<-- [start:replicated_linear] -@CustomOp.register("replicated_linear") +@PluggableLayer.register("replicated_linear") class ReplicatedLinear(LinearBase): """Replicated linear layer. @@ -414,7 +414,7 @@ class ReplicatedLinear(LinearBase): # --8<-- [start:column_parallel_linear] -@CustomOp.register("column_parallel_linear") +@PluggableLayer.register("column_parallel_linear") class ColumnParallelLinear(LinearBase): """Linear layer with column parallelism. @@ -1273,7 +1273,7 @@ class QKVParallelLinear(ColumnParallelLinear): # --8<-- [start:row_parallel_linear] -@CustomOp.register("row_parallel_linear") +@PluggableLayer.register("row_parallel_linear") class RowParallelLinear(LinearBase): """Linear layer with row parallelism.