[ Misc ] Apply MoE Refactor to Deepseekv2 To Support Fp8 (#6417)
@@ -1,5 +1,5 @@
 from abc import abstractmethod
-from typing import Optional
+from typing import List, Optional, Tuple

 import torch

@@ -29,7 +29,10 @@ class FusedMoEMethodBase(QuantizeMethodBase):
               x: torch.Tensor,
               router_logits: torch.Tensor,
               top_k: int,
-              renormalize: bool = True) -> torch.Tensor:
+              renormalize: bool = True,
+              use_grouped_topk: bool = False,
+              num_expert_group: Optional[int] = None,
+              topk_group: Optional[int] = None) -> torch.Tensor:
         raise NotImplementedError


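For context, the new use_grouped_topk, num_expert_group and topk_group arguments added above correspond to DeepSeek-V2-style grouped expert selection: the experts are partitioned into num_expert_group groups, the topk_group best-scoring groups are kept, and the final top_k experts are chosen only from those groups. A minimal standalone sketch of that selection logic, independent of the fused kernels (the helper name is illustrative, not part of this PR):

import torch


def grouped_topk_sketch(router_logits: torch.Tensor, top_k: int,
                        num_expert_group: int, topk_group: int,
                        renormalize: bool = True):
    # router_logits: [num_tokens, num_experts]
    scores = torch.softmax(router_logits, dim=-1)
    num_tokens, num_experts = scores.shape

    # Score each group by its best expert and keep the top `topk_group` groups.
    group_scores = scores.view(num_tokens, num_expert_group,
                               -1).max(dim=-1).values
    group_idx = torch.topk(group_scores, k=topk_group, dim=-1).indices

    # Zero out experts that belong to the discarded groups.
    group_mask = torch.zeros_like(group_scores)
    group_mask.scatter_(1, group_idx, 1.0)
    expert_mask = group_mask.repeat_interleave(num_experts // num_expert_group,
                                               dim=-1)
    masked_scores = scores.masked_fill(expert_mask == 0, 0.0)

    # Ordinary top-k over the surviving experts.
    topk_weights, topk_ids = torch.topk(masked_scores, k=top_k, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids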
@@ -63,7 +66,10 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase):
               x: torch.Tensor,
               router_logits: torch.Tensor,
               top_k: int,
-              renormalize: bool = True) -> torch.Tensor:
+              renormalize: bool = True,
+              use_grouped_topk: bool = False,
+              num_expert_group: Optional[int] = None,
+              topk_group: Optional[int] = None) -> torch.Tensor:

         return fused_moe(x,
                          layer.w13_weight,
@@ -71,7 +77,10 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase):
                          router_logits,
                          top_k,
                          renormalize=renormalize,
-                         inplace=True)
+                         inplace=True,
+                         use_grouped_topk=use_grouped_topk,
+                         num_expert_group=num_expert_group,
+                         topk_group=topk_group)


 class FusedMoE(torch.nn.Module):
@@ -104,6 +113,9 @@ class FusedMoE(torch.nn.Module):
         params_dtype: Optional[torch.dtype] = None,
         reduce_results: bool = False,
         renormalize: bool = True,
+        use_grouped_topk: bool = False,
+        num_expert_group: Optional[int] = None,
+        topk_group: Optional[int] = None,
         quant_config: Optional[QuantizationConfig] = None,
         tp_size: Optional[int] = None,
     ):
@@ -119,6 +131,11 @@ class FusedMoE(torch.nn.Module):
         self.intermediate_size_per_partition = intermediate_size // self.tp_size
         self.reduce_results = reduce_results
         self.renormalize = renormalize
+        self.use_grouped_topk = use_grouped_topk
+        if self.use_grouped_topk:
+            assert num_expert_group is not None and topk_group is not None
+        self.num_expert_group = num_expert_group
+        self.topk_group = topk_group

         if quant_config is None:
             self.quant_method: Optional[QuantizeMethodBase] = (
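With the constructor extended as above, a DeepSeek-V2-style model can instantiate the layer roughly as sketched below. The leading positional arguments belong to the unchanged part of the signature, and the numeric values and import path are illustrative placeholders rather than the model's actual config wiring:

from vllm.model_executor.layers.fused_moe import FusedMoE

quant_config = None  # or a QuantizationConfig (e.g. fp8) to quantize the experts

experts = FusedMoE(
    num_experts=160,
    top_k=6,
    hidden_size=5120,
    intermediate_size=1536,
    reduce_results=False,
    renormalize=False,          # placeholder; set from the model config
    use_grouped_topk=True,      # enable grouped routing
    num_expert_group=8,         # split experts into 8 groups
    topk_group=3,               # route only within the best 3 groups
    quant_config=quant_config,
)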
@@ -140,9 +157,8 @@ class FusedMoE(torch.nn.Module):
                       shard_id: int, expert_id: int):
         param_data = param.data

-        # FIXME(robertgshaw2-neuralmagic): Overfit to Mixtral.
-        # Follow up PR to enable fp8 for other MoE models.
-        if "input_scale" in weight_name or "w2.weight_scale" in weight_name:
+        # Input scales can be loaded directly and should be equal.
+        if "input_scale" in weight_name:
             if param_data[expert_id] != 1 and (param_data[expert_id] -
                                                loaded_weight).abs() > 1e-5:
                 raise ValueError(
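The input-scale branch in this hunk relies on w1 (gate_proj) and w3 (up_proj) being merged into a single w13 matmul: one fp8 activation scale has to serve both projections, so per-projection input scales from a checkpoint are only usable when they (nearly) agree. A toy illustration of the same check, with made-up numbers:

import torch

w1_input_scale = torch.tensor(0.0421)  # hypothetical checkpoint values
w3_input_scale = torch.tensor(0.0421)
if (w1_input_scale - w3_input_scale).abs() > 1e-5:
    raise ValueError("input_scales of w1 and w3 of a layer must be equal")
a13_scale = w1_input_scale             # single shared activation scale for w13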
@@ -150,14 +166,21 @@ class FusedMoE(torch.nn.Module):
                     f"must be equal. But got {param_data[expert_id]} "
                     f"vs. {loaded_weight}")
             param_data[expert_id] = loaded_weight
-        # FIXME(robertgshaw2-neuralmagic): Overfit to Mixtral.
-        # Follow up PR to enable fp8 for other MoE models.
+        # Weight scales
         elif "weight_scale" in weight_name:
-            # We have to keep the weight scales of w1 and w3 because
-            # we need to re-quantize w1/w3 weights after weight loading.
-            assert "w1" in weight_name or "w3" in weight_name
-            shard_id = 0 if "w1" in weight_name else 1
-            param_data[expert_id][shard_id] = loaded_weight
+            # If we are in merged column case (gate_up_proj)
+            # shard_id 0 == gate_proj / w1
+            # shard_id 2 == up_proj / w3
+            if shard_id == 0 or shard_id == 2:
+                # We have to keep the weight scales of w1 and w3 because
+                # we need to re-quantize w1/w3 weights after weight loading.
+                idx = 0 if shard_id == 0 else 1
+                param_data[expert_id][idx] = loaded_weight
+            # If we are in the row parallel case (down_proj)
+            # shard_id 1 == down_proj / w2
+            else:
+                param_data[expert_id] = loaded_weight
+        # Weights
         else:
             tp_rank = get_tensor_model_parallel_rank()
             shard_size = self.intermediate_size_per_partition
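In the weight-scale branch above, the merged w13 scale parameter presumably keeps two slots per expert (one for gate_proj/w1 and one for up_proj/w3, since both still need re-quantization to a shared scale after loading), while the down_proj/w2 scale is a single value per expert. A small sketch of that slot mapping, using the (gate, down, up) = (0, 1, 2) shard ordering from make_expert_params_mapping below:

def scale_slot(shard_id: int):
    # Returns (parameter name, index within the expert's entry); None means
    # the loaded scale fills the whole per-expert entry.
    if shard_id in (0, 2):               # gate_proj / up_proj -> merged w13
        return "w13_scale", 0 if shard_id == 0 else 1
    return "w2_scale", None              # down_proj -> w2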
@@ -188,10 +211,50 @@ class FusedMoE(torch.nn.Module):
             x=hidden_states,
             router_logits=router_logits,
             top_k=self.top_k,
-            renormalize=self.renormalize)
+            renormalize=self.renormalize,
+            use_grouped_topk=self.use_grouped_topk,
+            num_expert_group=self.num_expert_group,
+            topk_group=self.topk_group)

         if self.reduce_results and self.tp_size > 1:
             final_hidden_states = tensor_model_parallel_all_reduce(
                 final_hidden_states)

         return final_hidden_states
+
+    @classmethod
+    def make_expert_params_mapping(
+            cls, ckpt_gate_proj_name: str, ckpt_down_proj_name: str,
+            ckpt_up_proj_name: str,
+            num_experts: int) -> List[Tuple[str, str, int, int]]:
+
+        gate_up = [ckpt_gate_proj_name, ckpt_up_proj_name]
+        gate_down_up = [
+            ckpt_gate_proj_name, ckpt_down_proj_name, ckpt_up_proj_name
+        ]
+
+        return [
+            # These are the weight scales for the experts
+            # (param_name, weight_name, expert_id, shard_id)
+            ("experts.w13_scale"
+             if weight_name in gate_up else "experts.w2_scale",
+             f"experts.{expert_id}.{weight_name}.weight_scale", expert_id,
+             shard_id) for expert_id in range(num_experts)
+            for shard_id, weight_name in enumerate(gate_down_up)
+        ] + [
+            # These are the weights for the experts
+            # (param_name, weight_name, expert_id, shard_id)
+            ("experts.w13_weight"
+             if weight_name in gate_up else "experts.w2_weight",
+             f"experts.{expert_id}.{weight_name}.weight", expert_id, shard_id)
+            for expert_id in range(num_experts)
+            for shard_id, weight_name in enumerate(gate_down_up)
+        ] + [
+            # These are the weight scales for the experts
+            # (param_name, weight_name, expert_id, shard_id)
+            ("experts.a13_scale"
+             if weight_name in gate_up else "experts.a2_scale",
+             f"experts.{expert_id}.{weight_name}.input_scale", expert_id,
+             shard_id) for expert_id in range(num_experts)
+            for shard_id, weight_name in enumerate(gate_down_up)
+        ]
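A rough sketch of how a model's load_weights loop consumes the mapping returned by make_expert_params_mapping above, mirroring the Mixtral/DeepSeek-V2 pattern but simplified (weights, params_dict and num_experts are assumed to exist in the caller):

expert_params_mapping = FusedMoE.make_expert_params_mapping(
    ckpt_gate_proj_name="gate_proj",
    ckpt_down_proj_name="down_proj",
    ckpt_up_proj_name="up_proj",
    num_experts=num_experts)

for name, loaded_weight in weights:  # iterator over checkpoint tensors
    for param_name, weight_name, expert_id, shard_id in expert_params_mapping:
        if weight_name not in name:
            continue
        name = name.replace(weight_name, param_name)
        param = params_dict[name]
        weight_loader = param.weight_loader   # set on the param by FusedMoE
        weight_loader(param,
                      loaded_weight,
                      weight_name,
                      shard_id=shard_id,
                      expert_id=expert_id)
        break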