[Model] Adding support for MSFT Phi-3.5-MoE (#7729)

Co-authored-by: Your Name <you@example.com>
Co-authored-by: Zeqi Lin <zelin@microsoft.com>
Co-authored-by: Zeqi Lin <Zeqi.Lin@microsoft.com>
Author: Wenxiang
Date:   2024-08-31 03:42:57 +08:00
Committed by: GitHub
Parent: 2684efc467
Commit: 1248e8506a
13 changed files with 1254 additions and 81 deletions

vllm/model_executor/layers/fused_moe/layer.py

@@ -1,6 +1,6 @@
 from abc import abstractmethod
 from enum import Enum
-from typing import List, Optional, Tuple
+from typing import Callable, List, Optional, Tuple
 
 import torch
@@ -62,15 +62,18 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         layer.register_parameter("w2_weight", w2_weight)
         set_weight_attrs(w2_weight, extra_weight_attrs)
 
-    def apply(self,
-              layer: torch.nn.Module,
-              x: torch.Tensor,
-              router_logits: torch.Tensor,
-              top_k: int,
-              renormalize: bool,
-              use_grouped_topk: bool,
-              topk_group: Optional[int] = None,
-              num_expert_group: Optional[int] = None) -> torch.Tensor:
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool,
+        use_grouped_topk: bool,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        custom_routing_function: Optional[Callable] = None
+    ) -> torch.Tensor:
         return self.forward(x=x,
                             layer=layer,
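
Note: judging from how the callable is invoked in the select_experts hunk further down, a custom routing function must accept the same keyword arguments as fused_topk and return a (topk_weights, topk_ids) pair. A minimal sketch of that contract; the function name and the scoring scheme are purely illustrative, not the commit's code:

```python
from typing import Tuple

import torch


def example_routing_function(
        hidden_states: torch.Tensor,  # [num_tokens, hidden_size], unused here
        gating_output: torch.Tensor,  # [num_tokens, num_experts] router logits
        topk: int,
        renormalize: bool) -> Tuple[torch.Tensor, torch.Tensor]:
    # Score the experts, then keep the k best per token.
    scores = torch.sigmoid(gating_output)  # illustrative scoring choice
    topk_weights, topk_ids = torch.topk(scores, k=topk, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids
```
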
@@ -79,17 +82,21 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
                             renormalize=renormalize,
                             use_grouped_topk=use_grouped_topk,
                             topk_group=topk_group,
-                            num_expert_group=num_expert_group)
+                            num_expert_group=num_expert_group,
+                            custom_routing_function=custom_routing_function)
 
-    def forward_cuda(self,
-                     layer: torch.nn.Module,
-                     x: torch.Tensor,
-                     use_grouped_topk: bool,
-                     top_k: int,
-                     router_logits: torch.Tensor,
-                     renormalize: bool,
-                     topk_group: Optional[int] = None,
-                     num_expert_group: Optional[int] = None) -> torch.Tensor:
+    def forward_cuda(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        use_grouped_topk: bool,
+        top_k: int,
+        router_logits: torch.Tensor,
+        renormalize: bool,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        custom_routing_function: Optional[Callable] = None
+    ) -> torch.Tensor:
         from vllm.model_executor.layers.fused_moe.fused_moe import (
             fused_experts)
@@ -101,7 +108,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             top_k=top_k,
             renormalize=renormalize,
             topk_group=topk_group,
-            num_expert_group=num_expert_group)
+            num_expert_group=num_expert_group,
+            custom_routing_function=custom_routing_function)
 
         return fused_experts(hidden_states=x,
                              w1=layer.w13_weight,
@@ -114,20 +122,24 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         raise NotImplementedError(
             "The CPU backend currently does not support MoE.")
 
-    def forward_tpu(self,
-                    layer: torch.nn.Module,
-                    x: torch.Tensor,
-                    use_grouped_topk: bool,
-                    top_k: int,
-                    router_logits: torch.Tensor,
-                    renormalize: bool,
-                    topk_group: Optional[int] = None,
-                    num_expert_group: Optional[int] = None) -> torch.Tensor:
+    def forward_tpu(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        use_grouped_topk: bool,
+        top_k: int,
+        router_logits: torch.Tensor,
+        renormalize: bool,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        custom_routing_function: Optional[Callable] = None
+    ) -> torch.Tensor:
         from vllm.model_executor.layers.fused_moe.moe_pallas import fused_moe
         assert not use_grouped_topk
         assert num_expert_group is None
         assert topk_group is None
+        assert custom_routing_function is None
         return fused_moe(hidden_states=x,
                          w1=layer.w13_weight,
                          w2=layer.w2_weight,
@@ -172,6 +184,7 @@ class FusedMoE(torch.nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
         tp_size: Optional[int] = None,
         prefix: str = "",
+        custom_routing_function: Optional[Callable] = None,
     ):
         super().__init__()
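
Note: this constructor parameter is the hook a model implementation uses to supply its own router. A hedged wiring sketch follows; apart from custom_routing_function, the keyword arguments below come from FusedMoE's pre-existing signature but are not shown in this diff, the config attribute names are borrowed from Mixtral-style configs, and example_routing_function is the illustrative callable sketched earlier:

```python
import torch.nn as nn

from vllm.model_executor.layers.fused_moe import FusedMoE


class ExampleMoEBlock(nn.Module):
    """Hypothetical MoE block showing how a model passes its own router."""

    def __init__(self, config, quant_config=None):
        super().__init__()
        # Route each token with the model-specific function instead of the
        # default softmax top-k (assumed argument names, sketch only).
        self.experts = FusedMoE(
            num_experts=config.num_local_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            reduce_results=True,
            renormalize=False,
            quant_config=quant_config,
            custom_routing_function=example_routing_function)
```
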
@@ -190,6 +203,7 @@ class FusedMoE(torch.nn.Module):
             assert num_expert_group is not None and topk_group is not None
             self.num_expert_group = num_expert_group
             self.topk_group = topk_group
+        self.custom_routing_function = custom_routing_function
 
         if quant_config is None:
             self.quant_method: Optional[QuantizeMethodBase] = (
@@ -390,7 +404,8 @@ class FusedMoE(torch.nn.Module):
                        use_grouped_topk: bool,
                        renormalize: bool,
                        topk_group: Optional[int] = None,
-                       num_expert_group: Optional[int] = None):
+                       num_expert_group: Optional[int] = None,
+                       custom_routing_function: Optional[Callable] = None):
         from vllm.model_executor.layers.fused_moe.fused_moe import (
             fused_topk, grouped_topk)
@@ -405,11 +420,17 @@ class FusedMoE(torch.nn.Module):
                 renormalize=renormalize,
                 num_expert_group=num_expert_group,
                 topk_group=topk_group)
-        else:
+        elif custom_routing_function is None:
             topk_weights, topk_ids = fused_topk(hidden_states=hidden_states,
                                                 gating_output=router_logits,
                                                 topk=top_k,
                                                 renormalize=renormalize)
+        else:
+            topk_weights, topk_ids = custom_routing_function(
+                hidden_states=hidden_states,
+                gating_output=router_logits,
+                topk=top_k,
+                renormalize=renormalize)
 
         return topk_weights, topk_ids
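
Note: for contrast with the custom branch above, here is a plain-PyTorch sketch of what the default fused_topk path computes, assuming standard softmax top-k routing; the real kernel fuses these steps on the GPU, so this is a reference illustration, not vLLM's implementation:

```python
import torch


def unfused_topk(hidden_states: torch.Tensor,
                 gating_output: torch.Tensor,
                 topk: int,
                 renormalize: bool):
    # Softmax over the router logits, then keep the k best experts per token.
    scores = torch.softmax(gating_output, dim=-1, dtype=torch.float32)
    topk_weights, topk_ids = torch.topk(scores, k=topk, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids
```
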
@@ -426,7 +447,8 @@ class FusedMoE(torch.nn.Module):
             renormalize=self.renormalize,
             use_grouped_topk=self.use_grouped_topk,
             topk_group=self.topk_group,
-            num_expert_group=self.num_expert_group)
+            num_expert_group=self.num_expert_group,
+            custom_routing_function=self.custom_routing_function)
 
         if self.reduce_results and self.tp_size > 1:
             final_hidden_states = tensor_model_parallel_all_reduce(