[ Misc ] Apply MoE Refactor to Deepseekv2 To Support Fp8 (#6417)
This commit is contained in:
@@ -50,6 +50,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import IntermediateTensors, SamplerOutput
|
||||
from vllm.utils import print_warning_once
|
||||
|
||||
|
||||
class Qwen2MoeMLP(nn.Module):
|
||||
@@ -406,15 +407,13 @@ class Qwen2MoeForCausalLM(nn.Module):
|
||||
("gate_up_proj", "up_proj", 1),
|
||||
]
|
||||
|
||||
expert_params_mapping = [
|
||||
# These are the weights for the experts
|
||||
# (param_name, weight_name, expert_id, shard_id)
|
||||
("experts.w13_weight" if weight_name in ["gate_proj", "up_proj"]
|
||||
else "experts.w2_weight",
|
||||
f"experts.{expert_id}.{weight_name}.weight", expert_id, shard_id)
|
||||
for expert_id in range(self.config.num_experts) for shard_id,
|
||||
weight_name in enumerate(["gate_proj", "down_proj", "up_proj"])
|
||||
]
|
||||
# Params for weights, fp8 weight scales, fp8 activation scales
|
||||
# (param_name, weight_name, expert_id, shard_id)
|
||||
expert_params_mapping = FusedMoE.make_expert_params_mapping(
|
||||
ckpt_gate_proj_name="gate_proj",
|
||||
ckpt_down_proj_name="down_proj",
|
||||
ckpt_up_proj_name="up_proj",
|
||||
num_experts=self.config.num_experts)
|
||||
|
||||
params_dict = dict(self.named_parameters())
|
||||
for name, loaded_weight in weights:
|
||||
@@ -461,8 +460,20 @@ class Qwen2MoeForCausalLM(nn.Module):
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
if name not in params_dict:
|
||||
continue
|
||||
# Remapping the name of FP8 kv-scale.
|
||||
if name.endswith("kv_scale"):
|
||||
remapped_kv_scale_name = name.replace(
|
||||
".kv_scale", ".attn.kv_scale")
|
||||
if remapped_kv_scale_name not in params_dict:
|
||||
print_warning_once(
|
||||
"Found kv scale in the checkpoint "
|
||||
f"(e.g. {name}), but not found the expected "
|
||||
f"name in the model "
|
||||
f"(e.g. {remapped_kv_scale_name}). "
|
||||
"kv-scale is not loaded.")
|
||||
continue
|
||||
else:
|
||||
name = remapped_kv_scale_name
|
||||
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
|
||||
Reference in New Issue
Block a user