[ Misc ] Apply MoE Refactor to Deepseekv2 To Support Fp8 (#6417)

commit fb6af8bc08
parent eeceadaecc
Author: Robert Shaw
Date: 2024-07-13 23:03:58 -04:00
Committed by: GitHub
9 changed files with 222 additions and 136 deletions


@@ -1,5 +1,5 @@
 from abc import abstractmethod
-from typing import Optional
+from typing import List, Optional, Tuple
 import torch
@@ -29,7 +29,10 @@ class FusedMoEMethodBase(QuantizeMethodBase):
               x: torch.Tensor,
               router_logits: torch.Tensor,
               top_k: int,
-              renormalize: bool = True) -> torch.Tensor:
+              renormalize: bool = True,
+              use_grouped_topk: bool = False,
+              num_expert_group: Optional[int] = None,
+              topk_group: Optional[int] = None) -> torch.Tensor:
         raise NotImplementedError
@@ -63,7 +66,10 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase):
               x: torch.Tensor,
               router_logits: torch.Tensor,
               top_k: int,
-              renormalize: bool = True) -> torch.Tensor:
+              renormalize: bool = True,
+              use_grouped_topk: bool = False,
+              num_expert_group: Optional[int] = None,
+              topk_group: Optional[int] = None) -> torch.Tensor:
         return fused_moe(x,
                          layer.w13_weight,
@@ -71,7 +77,10 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase):
                          router_logits,
                          top_k,
                          renormalize=renormalize,
-                         inplace=True)
+                         inplace=True,
+                         use_grouped_topk=use_grouped_topk,
+                         num_expert_group=num_expert_group,
+                         topk_group=topk_group)
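
The three new routing arguments implement DeepSeek-V2-style grouped top-k expert selection: experts are partitioned into num_expert_group groups, the topk_group best groups are kept per token, and the usual top-k is then taken only over experts in those groups. Below is a minimal sketch of that selection, assuming softmax router scores; the helper name grouped_topk and its exact signature are illustrative, not the fused kernel's API.

import torch


def grouped_topk(router_logits: torch.Tensor, top_k: int,
                 num_expert_group: int, topk_group: int,
                 renormalize: bool = True):
    # Per-token expert scores, shape [num_tokens, num_experts].
    scores = torch.softmax(router_logits, dim=-1)
    num_tokens, num_experts = scores.shape
    # Score each expert group by its best expert.
    group_scores = scores.view(num_tokens, num_expert_group,
                               -1).max(dim=-1).values
    # Keep only the topk_group highest-scoring groups per token.
    group_idx = torch.topk(group_scores, k=topk_group, dim=-1).indices
    group_mask = torch.zeros_like(group_scores).scatter_(1, group_idx, 1)
    score_mask = group_mask.unsqueeze(-1).expand(
        num_tokens, num_expert_group,
        num_experts // num_expert_group).reshape(num_tokens, -1)
    # Mask out experts from discarded groups, then take the usual top-k.
    topk_weights, topk_ids = torch.topk(
        scores.masked_fill(score_mask == 0, 0.0), k=top_k, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids


# Example: 16 experts in 4 groups, keep 2 groups, route each token to 2 experts.
weights, ids = grouped_topk(torch.randn(4, 16), top_k=2,
                            num_expert_group=4, topk_group=2)
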
@@ -104,6 +113,9 @@ class FusedMoE(torch.nn.Module):
         params_dtype: Optional[torch.dtype] = None,
         reduce_results: bool = False,
         renormalize: bool = True,
+        use_grouped_topk: bool = False,
+        num_expert_group: Optional[int] = None,
+        topk_group: Optional[int] = None,
         quant_config: Optional[QuantizationConfig] = None,
         tp_size: Optional[int] = None,
     ):
@@ -119,6 +131,11 @@
         self.intermediate_size_per_partition = intermediate_size // self.tp_size
         self.reduce_results = reduce_results
         self.renormalize = renormalize
+        self.use_grouped_topk = use_grouped_topk
+        if self.use_grouped_topk:
+            assert num_expert_group is not None and topk_group is not None
+        self.num_expert_group = num_expert_group
+        self.topk_group = topk_group
         if quant_config is None:
             self.quant_method: Optional[QuantizeMethodBase] = (
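
With these attributes in place, a DeepSeek-V2-style model can enable grouped routing simply by passing its config values. The snippet below is only a sketch of that wiring; the config field names (n_routed_experts, num_experts_per_tok, moe_intermediate_size, n_group, topk_group, norm_topk_prob) follow the DeepSeek-V2 Hugging Face config, and config/quant_config are assumed to be in scope as in a normal vLLM model definition.

# Sketch: instantiating FusedMoE from a DeepSeek-V2-style MoE block.
self.experts = FusedMoE(
    num_experts=config.n_routed_experts,
    top_k=config.num_experts_per_tok,
    hidden_size=config.hidden_size,
    intermediate_size=config.moe_intermediate_size,
    reduce_results=False,
    renormalize=config.norm_topk_prob,
    use_grouped_topk=True,
    num_expert_group=config.n_group,
    topk_group=config.topk_group,
    quant_config=quant_config)
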
@@ -140,9 +157,8 @@ class FusedMoE(torch.nn.Module):
                       shard_id: int, expert_id: int):
         param_data = param.data
-        # FIXME(robertgshaw2-neuralmagic): Overfit to Mixtral.
-        # Follow up PR to enable fp8 for other MoE models.
-        if "input_scale" in weight_name or "w2.weight_scale" in weight_name:
+        # Input scales can be loaded directly and should be equal.
+        if "input_scale" in weight_name:
             if param_data[expert_id] != 1 and (param_data[expert_id] -
                                                loaded_weight).abs() > 1e-5:
                 raise ValueError(
@@ -150,14 +166,21 @@ class FusedMoE(torch.nn.Module):
                     f"must be equal. But got {param_data[expert_id]} "
                     f"vs. {loaded_weight}")
             param_data[expert_id] = loaded_weight
-        # FIXME(robertgshaw2-neuralmagic): Overfit to Mixtral.
-        # Follow up PR to enable fp8 for other MoE models.
+        # Weight scales
         elif "weight_scale" in weight_name:
-            # We have to keep the weight scales of w1 and w3 because
-            # we need to re-quantize w1/w3 weights after weight loading.
-            assert "w1" in weight_name or "w3" in weight_name
-            shard_id = 0 if "w1" in weight_name else 1
-            param_data[expert_id][shard_id] = loaded_weight
+            # If we are in merged column case (gate_up_proj)
+            #   shard_id 0 == gate_proj / w1
+            #   shard_id 2 == up_proj / w3
+            if shard_id == 0 or shard_id == 2:
+                # We have to keep the weight scales of w1 and w3 because
+                # we need to re-quantize w1/w3 weights after weight loading.
+                idx = 0 if shard_id == 0 else 1
+                param_data[expert_id][idx] = loaded_weight
+            # If we are in the row parallel case (down_proj)
+            #   shard_id 1 == down_proj / w2
+            else:
+                param_data[expert_id] = loaded_weight
+        # Weights
         else:
             tp_rank = get_tensor_model_parallel_rank()
             shard_size = self.intermediate_size_per_partition
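
Keeping both per-shard scales matters because w1 and w3 live in one merged w13 tensor but may have been quantized with different fp8 scales; after loading, the quantization method can dequantize each half with its own scale and requantize both with the shared maximum. The following is a rough sketch of that post-load step, assuming plain per-tensor fp8-style scaling; the function name and tensor layout are illustrative, not vLLM's exact implementation.

import torch


def requantize_merged_w13(w13_weight: torch.Tensor,
                          w13_scale: torch.Tensor,
                          shard_size: int) -> torch.Tensor:
    """Fold the two per-shard scales of the merged w1/w3 weight into one
    per-expert scale: dequantize each shard with its own scale, then
    requantize both shards with the max of the two."""
    max_scale = w13_scale.max(dim=1).values  # [num_experts]
    for expert in range(w13_weight.shape[0]):
        start = 0
        for shard in range(2):  # shard 0 = w1 / gate_proj, shard 1 = w3 / up_proj
            block = w13_weight[expert, start:start + shard_size, :].float()
            dequant = block * w13_scale[expert, shard]  # back to higher precision
            w13_weight[expert, start:start + shard_size, :] = (
                dequant / max_scale[expert]).to(w13_weight.dtype)
            start += shard_size
    return max_scale  # single per-expert w13 scale used at runtime
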
@@ -188,10 +211,50 @@ class FusedMoE(torch.nn.Module):
             x=hidden_states,
             router_logits=router_logits,
             top_k=self.top_k,
-            renormalize=self.renormalize)
+            renormalize=self.renormalize,
+            use_grouped_topk=self.use_grouped_topk,
+            num_expert_group=self.num_expert_group,
+            topk_group=self.topk_group)
         if self.reduce_results and self.tp_size > 1:
             final_hidden_states = tensor_model_parallel_all_reduce(
                 final_hidden_states)
         return final_hidden_states
+    @classmethod
+    def make_expert_params_mapping(
+            cls, ckpt_gate_proj_name: str, ckpt_down_proj_name: str,
+            ckpt_up_proj_name: str,
+            num_experts: int) -> List[Tuple[str, str, int, int]]:
+        gate_up = [ckpt_gate_proj_name, ckpt_up_proj_name]
+        gate_down_up = [
+            ckpt_gate_proj_name, ckpt_down_proj_name, ckpt_up_proj_name
+        ]
+        return [
+            # These are the weight scales for the experts
+            # (param_name, weight_name, expert_id, shard_id)
+            ("experts.w13_scale"
+             if weight_name in gate_up else "experts.w2_scale",
+             f"experts.{expert_id}.{weight_name}.weight_scale", expert_id,
+             shard_id) for expert_id in range(num_experts)
+            for shard_id, weight_name in enumerate(gate_down_up)
+        ] + [
+            # These are the weights for the experts
+            # (param_name, weight_name, expert_id, shard_id)
+            ("experts.w13_weight"
+             if weight_name in gate_up else "experts.w2_weight",
+             f"experts.{expert_id}.{weight_name}.weight", expert_id, shard_id)
+            for expert_id in range(num_experts)
+            for shard_id, weight_name in enumerate(gate_down_up)
+        ] + [
+            # These are the activation (input) scales for the experts
+            # (param_name, weight_name, expert_id, shard_id)
+            ("experts.a13_scale"
+             if weight_name in gate_up else "experts.a2_scale",
+             f"experts.{expert_id}.{weight_name}.input_scale", expert_id,
+             shard_id) for expert_id in range(num_experts)
+            for shard_id, weight_name in enumerate(gate_down_up)
+        ]
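
Models consume this mapping in their load_weights loop: each checkpoint tensor name is matched against weight_name, rewritten to the fused parameter name, and handed to FusedMoE.weight_loader with the expert and shard indices. Here is a condensed sketch, assuming the usual params_dict/weights/num_experts variables and DeepSeek-V2-style checkpoint projection names (gate_proj/down_proj/up_proj); it is illustrative rather than a copy of any model file.

# Sketch of a model's load_weights loop using the mapping above.
expert_params_mapping = FusedMoE.make_expert_params_mapping(
    ckpt_gate_proj_name="gate_proj",
    ckpt_down_proj_name="down_proj",
    ckpt_up_proj_name="up_proj",
    num_experts=num_experts)

for name, loaded_weight in weights:  # (checkpoint name, tensor) pairs
    for param_name, weight_name, expert_id, shard_id in expert_params_mapping:
        if weight_name not in name:
            continue
        # e.g. "...experts.5.gate_proj.weight_scale" -> "...experts.w13_scale"
        name = name.replace(weight_name, param_name)
        param = params_dict[name]
        # FusedMoE.weight_loader drops the tensor into the right expert/shard slot.
        param.weight_loader(param, loaded_weight, weight_name,
                            shard_id=shard_id, expert_id=expert_id)
        break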