[PluggableLayer][2/N] Apply PluggableLayer to linear layers (#33152)

Signed-off-by: whx-sjtu <2952154980@qq.com>
This commit is contained in:
whx
2026-01-29 16:53:15 +08:00
committed by GitHub
parent 3bba2edb0f
commit 08b1195e62

View File

@@ -17,7 +17,7 @@ from vllm.distributed import (
tensor_model_parallel_all_reduce, tensor_model_parallel_all_reduce,
) )
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.custom_op import PluggableLayer
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizationConfig,
QuantizeMethodBase, QuantizeMethodBase,
@@ -239,7 +239,7 @@ class UnquantizedLinearMethod(LinearMethodBase):
return dispatch_unquantized_gemm()(layer, x, layer.weight, bias) return dispatch_unquantized_gemm()(layer, x, layer.weight, bias)
class LinearBase(CustomOp): class LinearBase(PluggableLayer):
"""Base linear layer. """Base linear layer.
Args: Args:
@@ -294,7 +294,7 @@ class LinearBase(CustomOp):
# --8<-- [start:replicated_linear] # --8<-- [start:replicated_linear]
@CustomOp.register("replicated_linear") @PluggableLayer.register("replicated_linear")
class ReplicatedLinear(LinearBase): class ReplicatedLinear(LinearBase):
"""Replicated linear layer. """Replicated linear layer.
@@ -414,7 +414,7 @@ class ReplicatedLinear(LinearBase):
# --8<-- [start:column_parallel_linear] # --8<-- [start:column_parallel_linear]
@CustomOp.register("column_parallel_linear") @PluggableLayer.register("column_parallel_linear")
class ColumnParallelLinear(LinearBase): class ColumnParallelLinear(LinearBase):
"""Linear layer with column parallelism. """Linear layer with column parallelism.
@@ -1273,7 +1273,7 @@ class QKVParallelLinear(ColumnParallelLinear):
# --8<-- [start:row_parallel_linear] # --8<-- [start:row_parallel_linear]
@CustomOp.register("row_parallel_linear") @PluggableLayer.register("row_parallel_linear")
class RowParallelLinear(LinearBase): class RowParallelLinear(LinearBase):
"""Linear layer with row parallelism. """Linear layer with row parallelism.