[PluggableLayer][2/N] Apply PluggableLayer to linear layers (#33152)
Signed-off-by: whx-sjtu <2952154980@qq.com>
This commit is contained in:
@@ -17,7 +17,7 @@ from vllm.distributed import (
|
||||
tensor_model_parallel_all_reduce,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
from vllm.model_executor.custom_op import PluggableLayer
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig,
|
||||
QuantizeMethodBase,
|
||||
@@ -239,7 +239,7 @@ class UnquantizedLinearMethod(LinearMethodBase):
|
||||
return dispatch_unquantized_gemm()(layer, x, layer.weight, bias)
|
||||
|
||||
|
||||
class LinearBase(CustomOp):
|
||||
class LinearBase(PluggableLayer):
|
||||
"""Base linear layer.
|
||||
|
||||
Args:
|
||||
@@ -294,7 +294,7 @@ class LinearBase(CustomOp):
|
||||
|
||||
|
||||
# --8<-- [start:replicated_linear]
|
||||
@CustomOp.register("replicated_linear")
|
||||
@PluggableLayer.register("replicated_linear")
|
||||
class ReplicatedLinear(LinearBase):
|
||||
"""Replicated linear layer.
|
||||
|
||||
@@ -414,7 +414,7 @@ class ReplicatedLinear(LinearBase):
|
||||
|
||||
|
||||
# --8<-- [start:column_parallel_linear]
|
||||
@CustomOp.register("column_parallel_linear")
|
||||
@PluggableLayer.register("column_parallel_linear")
|
||||
class ColumnParallelLinear(LinearBase):
|
||||
"""Linear layer with column parallelism.
|
||||
|
||||
@@ -1273,7 +1273,7 @@ class QKVParallelLinear(ColumnParallelLinear):
|
||||
|
||||
|
||||
# --8<-- [start:row_parallel_linear]
|
||||
@CustomOp.register("row_parallel_linear")
|
||||
@PluggableLayer.register("row_parallel_linear")
|
||||
class RowParallelLinear(LinearBase):
|
||||
"""Linear layer with row parallelism.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user