[ Misc ] Apply MoE Refactor to Deepseekv2 To Support Fp8 (#6417)

This commit is contained in:
Robert Shaw
2024-07-13 23:03:58 -04:00
committed by GitHub
parent eeceadaecc
commit fb6af8bc08
9 changed files with 222 additions and 136 deletions

View File

@@ -50,6 +50,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors, SamplerOutput
from vllm.utils import print_warning_once
class Qwen2MoeMLP(nn.Module):
@@ -406,15 +407,13 @@ class Qwen2MoeForCausalLM(nn.Module):
("gate_up_proj", "up_proj", 1),
]
expert_params_mapping = [
# These are the weights for the experts
# (param_name, weight_name, expert_id, shard_id)
("experts.w13_weight" if weight_name in ["gate_proj", "up_proj"]
else "experts.w2_weight",
f"experts.{expert_id}.{weight_name}.weight", expert_id, shard_id)
for expert_id in range(self.config.num_experts) for shard_id,
weight_name in enumerate(["gate_proj", "down_proj", "up_proj"])
]
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.config.num_experts)
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
@@ -461,8 +460,20 @@ class Qwen2MoeForCausalLM(nn.Module):
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if name not in params_dict:
continue
# Remapping the name of FP8 kv-scale.
if name.endswith("kv_scale"):
remapped_kv_scale_name = name.replace(
".kv_scale", ".attn.kv_scale")
if remapped_kv_scale_name not in params_dict:
print_warning_once(
"Found kv scale in the checkpoint "
f"(e.g. {name}), but not found the expected "
f"name in the model "
f"(e.g. {remapped_kv_scale_name}). "
"kv-scale is not loaded.")
continue
else:
name = remapped_kv_scale_name
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",