[PluggableLayer][2/N] Apply PluggableLayer to linear layers (#33152)

Signed-off-by: whx-sjtu <2952154980@qq.com>
This commit is contained in:
whx
2026-01-29 16:53:15 +08:00
committed by GitHub
parent 3bba2edb0f
commit 08b1195e62

View File

@@ -17,7 +17,7 @@ from vllm.distributed import (
tensor_model_parallel_all_reduce,
)
from vllm.logger import init_logger
-from vllm.model_executor.custom_op import CustomOp
+from vllm.model_executor.custom_op import PluggableLayer
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig,
QuantizeMethodBase,
@@ -239,7 +239,7 @@ class UnquantizedLinearMethod(LinearMethodBase):
return dispatch_unquantized_gemm()(layer, x, layer.weight, bias)
-class LinearBase(CustomOp):
+class LinearBase(PluggableLayer):
"""Base linear layer.
Args:
@@ -294,7 +294,7 @@ class LinearBase(CustomOp):
# --8<-- [start:replicated_linear]
-@CustomOp.register("replicated_linear")
+@PluggableLayer.register("replicated_linear")
class ReplicatedLinear(LinearBase):
"""Replicated linear layer.
@@ -414,7 +414,7 @@ class ReplicatedLinear(LinearBase):
# --8<-- [start:column_parallel_linear]
-@CustomOp.register("column_parallel_linear")
+@PluggableLayer.register("column_parallel_linear")
class ColumnParallelLinear(LinearBase):
"""Linear layer with column parallelism.
@@ -1273,7 +1273,7 @@ class QKVParallelLinear(ColumnParallelLinear):
# --8<-- [start:row_parallel_linear]
-@CustomOp.register("row_parallel_linear")
+@PluggableLayer.register("row_parallel_linear")
class RowParallelLinear(LinearBase):
"""Linear layer with row parallelism.