[Model] Adding support for MSFT Phi-3.5-MoE (#7729)
Co-authored-by: Your Name <you@example.com>
Co-authored-by: Zeqi Lin <zelin@microsoft.com>
Co-authored-by: Zeqi Lin <Zeqi.Lin@microsoft.com>
All hunks below touch the fused-MoE layer (vllm/model_executor/layers/fused_moe/layer.py, judging by the class names in the hunk headers). Together they thread a new custom_routing_function argument through the MoE stack so that a model such as Phi-3.5-MoE can supply its own expert-routing logic.
@@ -1,6 +1,6 @@
 from abc import abstractmethod
 from enum import Enum
-from typing import List, Optional, Tuple
+from typing import Callable, List, Optional, Tuple
 
 import torch
 
@@ -62,15 +62,18 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         layer.register_parameter("w2_weight", w2_weight)
         set_weight_attrs(w2_weight, extra_weight_attrs)
 
-    def apply(self,
-              layer: torch.nn.Module,
-              x: torch.Tensor,
-              router_logits: torch.Tensor,
-              top_k: int,
-              renormalize: bool,
-              use_grouped_topk: bool,
-              topk_group: Optional[int] = None,
-              num_expert_group: Optional[int] = None) -> torch.Tensor:
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool,
+        use_grouped_topk: bool,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        custom_routing_function: Optional[Callable] = None
+    ) -> torch.Tensor:
 
         return self.forward(x=x,
                             layer=layer,
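The contract for the new callable can be read off its call site in FusedMoE.select_experts further down in this diff: it is invoked with the keyword arguments hidden_states, gating_output, topk, and renormalize, and must return a (topk_weights, topk_ids) pair, just like fused_topk. The same parameter is threaded, unchanged, through forward_cuda and forward_tpu below.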
@@ -79,17 +82,21 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
                             renormalize=renormalize,
                             use_grouped_topk=use_grouped_topk,
                             topk_group=topk_group,
-                            num_expert_group=num_expert_group)
+                            num_expert_group=num_expert_group,
+                            custom_routing_function=custom_routing_function)
 
-    def forward_cuda(self,
-                     layer: torch.nn.Module,
-                     x: torch.Tensor,
-                     use_grouped_topk: bool,
-                     top_k: int,
-                     router_logits: torch.Tensor,
-                     renormalize: bool,
-                     topk_group: Optional[int] = None,
-                     num_expert_group: Optional[int] = None) -> torch.Tensor:
+    def forward_cuda(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        use_grouped_topk: bool,
+        top_k: int,
+        router_logits: torch.Tensor,
+        renormalize: bool,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        custom_routing_function: Optional[Callable] = None
+    ) -> torch.Tensor:
 
         from vllm.model_executor.layers.fused_moe.fused_moe import (
             fused_experts)
@@ -101,7 +108,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             top_k=top_k,
             renormalize=renormalize,
             topk_group=topk_group,
-            num_expert_group=num_expert_group)
+            num_expert_group=num_expert_group,
+            custom_routing_function=custom_routing_function)
 
         return fused_experts(hidden_states=x,
                              w1=layer.w13_weight,
@@ -114,20 +122,24 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         raise NotImplementedError(
             "The CPU backend currently does not support MoE.")
 
-    def forward_tpu(self,
-                    layer: torch.nn.Module,
-                    x: torch.Tensor,
-                    use_grouped_topk: bool,
-                    top_k: int,
-                    router_logits: torch.Tensor,
-                    renormalize: bool,
-                    topk_group: Optional[int] = None,
-                    num_expert_group: Optional[int] = None) -> torch.Tensor:
+    def forward_tpu(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        use_grouped_topk: bool,
+        top_k: int,
+        router_logits: torch.Tensor,
+        renormalize: bool,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        custom_routing_function: Optional[Callable] = None
+    ) -> torch.Tensor:
 
         from vllm.model_executor.layers.fused_moe.moe_pallas import fused_moe
         assert not use_grouped_topk
         assert num_expert_group is None
         assert topk_group is None
+        assert custom_routing_function is None
         return fused_moe(hidden_states=x,
                          w1=layer.w13_weight,
                          w2=layer.w2_weight,
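Note that the Pallas TPU path gains only a guard, not an implementation: custom routing joins grouped top-k on the unsupported list there, and the new assert makes any attempt to use it fail fast.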
@@ -172,6 +184,7 @@ class FusedMoE(torch.nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
         tp_size: Optional[int] = None,
         prefix: str = "",
+        custom_routing_function: Optional[Callable] = None,
     ):
         super().__init__()
 
@@ -190,6 +203,7 @@ class FusedMoE(torch.nn.Module):
             assert num_expert_group is not None and topk_group is not None
             self.num_expert_group = num_expert_group
             self.topk_group = topk_group
+        self.custom_routing_function = custom_routing_function
 
         if quant_config is None:
             self.quant_method: Optional[QuantizeMethodBase] = (
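With the callable stored on the layer, a model can wire in its own router at construction time. A minimal sketch, assuming the package-level FusedMoE export that other in-tree models use and purely illustrative sizes; the constructor keywords other than custom_routing_function are not shown in this diff, so treat them as assumptions:

import torch

# Assumed import path: FusedMoE is defined in
# vllm/model_executor/layers/fused_moe/layer.py and re-exported by the package.
from vllm.model_executor.layers.fused_moe import FusedMoE


def simple_router(hidden_states: torch.Tensor, gating_output: torch.Tensor,
                  topk: int, renormalize: bool):
    # Bare softmax + top-k; ignores `renormalize` for brevity since we pass
    # renormalize=False below. A fuller sketch follows the select_experts hunk.
    return torch.topk(torch.softmax(gating_output.float(), dim=-1), topk, dim=-1)


# Inside a hypothetical model's __init__. Constructing FusedMoE requires
# vLLM's tensor-parallel state to be initialized, so this is a sketch
# rather than a standalone script.
experts = FusedMoE(num_experts=16,
                   top_k=2,
                   hidden_size=4096,
                   intermediate_size=6400,
                   renormalize=False,
                   custom_routing_function=simple_router)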
@@ -390,7 +404,8 @@ class FusedMoE(torch.nn.Module):
                        use_grouped_topk: bool,
                        renormalize: bool,
                        topk_group: Optional[int] = None,
-                       num_expert_group: Optional[int] = None):
+                       num_expert_group: Optional[int] = None,
+                       custom_routing_function: Optional[Callable] = None):
         from vllm.model_executor.layers.fused_moe.fused_moe import (
             fused_topk, grouped_topk)
 
@@ -405,11 +420,17 @@ class FusedMoE(torch.nn.Module):
                 renormalize=renormalize,
                 num_expert_group=num_expert_group,
                 topk_group=topk_group)
-        else:
+        elif custom_routing_function is None:
             topk_weights, topk_ids = fused_topk(hidden_states=hidden_states,
                                                 gating_output=router_logits,
                                                 topk=top_k,
                                                 renormalize=renormalize)
+        else:
+            topk_weights, topk_ids = custom_routing_function(
+                hidden_states=hidden_states,
+                gating_output=router_logits,
+                topk=top_k,
+                renormalize=renormalize)
 
         return topk_weights, topk_ids
 
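The dispatch is now three-way: grouped top-k first, the stock fused_topk when no callable is supplied, and the caller's function otherwise. As a concrete instance of the contract, here is a pure-PyTorch router that reproduces plain softmax-plus-top-k routing; it is a sketch of the expected behavior, not code from this PR:

from typing import Tuple

import torch


def softmax_topk_routing(hidden_states: torch.Tensor,
                         gating_output: torch.Tensor,
                         topk: int,
                         renormalize: bool) -> Tuple[torch.Tensor, torch.Tensor]:
    """Reference router matching the keyword contract used by select_experts.

    hidden_states: (num_tokens, hidden_size); unused here, but part of the
        contract so routers that condition on activations can use it.
    gating_output: (num_tokens, num_experts) router logits.
    Returns (topk_weights, topk_ids), each of shape (num_tokens, topk).
    """
    scores = torch.softmax(gating_output, dim=-1, dtype=torch.float32)
    topk_weights, topk_ids = torch.topk(scores, topk, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids


# Quick self-check with illustrative shapes:
if __name__ == "__main__":
    x = torch.randn(4, 32)        # (num_tokens, hidden_size)
    logits = torch.randn(4, 8)    # (num_tokens, num_experts)
    w, ids = softmax_topk_routing(x, logits, topk=2, renormalize=True)
    assert w.shape == ids.shape == (4, 2)
    assert torch.allclose(w.sum(-1), torch.ones(4))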
@@ -426,7 +447,8 @@ class FusedMoE(torch.nn.Module):
             renormalize=self.renormalize,
             use_grouped_topk=self.use_grouped_topk,
             topk_group=self.topk_group,
-            num_expert_group=self.num_expert_group)
+            num_expert_group=self.num_expert_group,
+            custom_routing_function=self.custom_routing_function)
 
         if self.reduce_results and self.tp_size > 1:
             final_hidden_states = tensor_model_parallel_all_reduce(