[PluggableLayer][2/N] Apply PluggableLayer to linear layers (#33152)
Signed-off-by: whx-sjtu <2952154980@qq.com>
This commit is contained in:
@@ -17,7 +17,7 @@ from vllm.distributed import (
|
|||||||
tensor_model_parallel_all_reduce,
|
tensor_model_parallel_all_reduce,
|
||||||
)
|
)
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.custom_op import CustomOp
|
from vllm.model_executor.custom_op import PluggableLayer
|
||||||
from vllm.model_executor.layers.quantization.base_config import (
|
from vllm.model_executor.layers.quantization.base_config import (
|
||||||
QuantizationConfig,
|
QuantizationConfig,
|
||||||
QuantizeMethodBase,
|
QuantizeMethodBase,
|
||||||
@@ -239,7 +239,7 @@ class UnquantizedLinearMethod(LinearMethodBase):
|
|||||||
return dispatch_unquantized_gemm()(layer, x, layer.weight, bias)
|
return dispatch_unquantized_gemm()(layer, x, layer.weight, bias)
|
||||||
|
|
||||||
|
|
||||||
class LinearBase(CustomOp):
|
class LinearBase(PluggableLayer):
|
||||||
"""Base linear layer.
|
"""Base linear layer.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -294,7 +294,7 @@ class LinearBase(CustomOp):
|
|||||||
|
|
||||||
|
|
||||||
# --8<-- [start:replicated_linear]
|
# --8<-- [start:replicated_linear]
|
||||||
@CustomOp.register("replicated_linear")
|
@PluggableLayer.register("replicated_linear")
|
||||||
class ReplicatedLinear(LinearBase):
|
class ReplicatedLinear(LinearBase):
|
||||||
"""Replicated linear layer.
|
"""Replicated linear layer.
|
||||||
|
|
||||||
@@ -414,7 +414,7 @@ class ReplicatedLinear(LinearBase):
|
|||||||
|
|
||||||
|
|
||||||
# --8<-- [start:column_parallel_linear]
|
# --8<-- [start:column_parallel_linear]
|
||||||
@CustomOp.register("column_parallel_linear")
|
@PluggableLayer.register("column_parallel_linear")
|
||||||
class ColumnParallelLinear(LinearBase):
|
class ColumnParallelLinear(LinearBase):
|
||||||
"""Linear layer with column parallelism.
|
"""Linear layer with column parallelism.
|
||||||
|
|
||||||
@@ -1273,7 +1273,7 @@ class QKVParallelLinear(ColumnParallelLinear):
|
|||||||
|
|
||||||
|
|
||||||
# --8<-- [start:row_parallel_linear]
|
# --8<-- [start:row_parallel_linear]
|
||||||
@CustomOp.register("row_parallel_linear")
|
@PluggableLayer.register("row_parallel_linear")
|
||||||
class RowParallelLinear(LinearBase):
|
class RowParallelLinear(LinearBase):
|
||||||
"""Linear layer with row parallelism.
|
"""Linear layer with row parallelism.
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user