[Misc] Add CustomOp Interface to UnquantizedFusedMoEMethod (#6289)
This commit is contained in:
@@ -7,7 +7,7 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
tensor_model_parallel_all_reduce)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_moe
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig, QuantizeMethodBase)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
@@ -36,7 +36,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class UnquantizedFusedMoEMethod(FusedMoEMethodBase):
|
||||
class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
"""MoE method without quantization."""
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module, num_experts: int,
|
||||
@@ -61,19 +61,37 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase):
|
||||
layer.register_parameter("w2_weight", w2_weight)
|
||||
set_weight_attrs(w2_weight, extra_weight_attrs)
|
||||
|
||||
def apply(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool = True,
|
||||
use_grouped_topk: bool = False,
|
||||
num_expert_group: Optional[int] = None,
|
||||
topk_group: Optional[int] = None) -> torch.Tensor:
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool = True,
|
||||
use_grouped_topk: bool = False,
|
||||
num_expert_group: Optional[int] = None,
|
||||
topk_group: Optional[int] = None,
|
||||
) -> torch.Tensor:
|
||||
return self.forward(x, layer.w13_weight, layer.w2_weight,
|
||||
router_logits, top_k, renormalize,
|
||||
use_grouped_topk, num_expert_group, topk_group)
|
||||
|
||||
def forward_cuda(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool,
|
||||
num_expert_group: Optional[int],
|
||||
topk_group: Optional[int],
|
||||
) -> torch.Tensor:
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_moe
|
||||
return fused_moe(x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
w1,
|
||||
w2,
|
||||
router_logits,
|
||||
top_k,
|
||||
renormalize=renormalize,
|
||||
@@ -82,6 +100,10 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase):
|
||||
num_expert_group=num_expert_group,
|
||||
topk_group=topk_group)
|
||||
|
||||
def forward_cpu(self, *args, **kwargs):
|
||||
raise NotImplementedError(
|
||||
"The CPU backend currently does not support MoE.")
|
||||
|
||||
|
||||
class FusedMoE(torch.nn.Module):
|
||||
"""FusedMoE layer for MoE models.
|
||||
|
||||
Reference in New Issue
Block a user