[Platform] Custom ops support for FusedMoe (#22509)
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
@@ -16,6 +16,7 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank,
|
||||
tensor_model_parallel_all_gather,
|
||||
tensor_model_parallel_all_reduce)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig, QuantizeMethodBase)
|
||||
from vllm.model_executor.layers.utils import dispatch_unquantized_gemm
|
||||
@@ -226,7 +227,7 @@ class UnquantizedLinearMethod(LinearMethodBase):
|
||||
return dispatch_unquantized_gemm()(layer, x, layer.weight, bias)
|
||||
|
||||
|
||||
class LinearBase(torch.nn.Module):
|
||||
class LinearBase(CustomOp):
|
||||
"""Base linear layer.
|
||||
|
||||
Args:
|
||||
@@ -269,12 +270,8 @@ class LinearBase(torch.nn.Module):
|
||||
prefix=prefix)
|
||||
self.return_bias = return_bias
|
||||
|
||||
def forward(
|
||||
self, x: torch.Tensor
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@CustomOp.register("replicated_linear")
|
||||
class ReplicatedLinear(LinearBase):
|
||||
"""Replicated linear layer.
|
||||
|
||||
@@ -443,6 +440,7 @@ class MergedReplicatedLinear(ReplicatedLinear):
|
||||
param[shard_offset:shard_offset + shard_size] = loaded_weight
|
||||
|
||||
|
||||
@CustomOp.register("column_parallel_linear")
|
||||
class ColumnParallelLinear(LinearBase):
|
||||
"""Linear layer with column parallelism.
|
||||
|
||||
@@ -1229,6 +1227,7 @@ class QKVParallelLinear(ColumnParallelLinear):
|
||||
param_data.copy_(loaded_weight)
|
||||
|
||||
|
||||
@CustomOp.register("row_parallel_linear")
|
||||
class RowParallelLinear(LinearBase):
|
||||
"""Linear layer with row parallelism.
|
||||
|
||||
@@ -1405,6 +1404,7 @@ class RowParallelLinear(LinearBase):
|
||||
return s
|
||||
|
||||
|
||||
@CustomOp.register("qkv_cross_parallel_linear")
|
||||
class QKVCrossParallelLinear(LinearBase):
|
||||
"""Linear layers for efficient cross-attention's QKV transformation.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user