[Bugfix] Fix DSV32 weight loading (#38870)
Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
This commit is contained in:
@@ -184,11 +184,16 @@ class DeepSeekMTP(nn.Module, DeepseekV2MixtureOfExperts):
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
    """Build the MTP wrapper around a ``DeepSeekMultiTokenPredictor``.

    Args:
        vllm_config: Source of the HF model config
            (``model_config.hf_config``) and the quantization config
            (``quant_config``); also forwarded to the predictor.
        prefix: Parent module prefix; the predictor is constructed
            under ``maybe_prefix(prefix, "model")``.
    """
    super().__init__()

    self.config = vllm_config.model_config.hf_config
    self.quant_config = vllm_config.quant_config

    model_prefix = maybe_prefix(prefix, "model")
    self.model = DeepSeekMultiTokenPredictor(
        vllm_config=vllm_config, prefix=model_prefix
    )

    # Populate the MoE bookkeeping attributes before computing the flag
    # below (same order as the original code).
    self.set_moe_parameters()

    # True only when a quant config is present and reports the name
    # "modelopt_fp4"; a missing quant config yields False.
    quant_name = (
        None if self.quant_config is None else self.quant_config.get_name()
    )
    self.is_fp4_ckpt = quant_name == "modelopt_fp4"
|
||||
|
||||
def set_moe_parameters(self):
|
||||
self.expert_weights = []
|
||||
@@ -241,11 +246,16 @@ class DeepSeekMTP(nn.Module, DeepseekV2MixtureOfExperts):
|
||||
("gate_up_proj", "up_proj", 1),
|
||||
("fused_qkv_a_proj", "q_a_proj", 0),
|
||||
("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1),
|
||||
# Fused indexer wk + weights_proj
|
||||
("wk_weights_proj", "wk", 0),
|
||||
("wk_weights_proj", "weights_proj", 1),
|
||||
]
|
||||
|
||||
if self.is_fp4_ckpt:
|
||||
# Fused indexer wk + weights_proj (shard 0 = wk, shard 1 = weights_proj)
|
||||
indexer_fused_mapping = [
|
||||
("wk_weights_proj", "wk", 0),
|
||||
("wk_weights_proj", "weights_proj", 1),
|
||||
]
|
||||
stacked_params_mapping.extend(indexer_fused_mapping)
|
||||
|
||||
expert_params_mapping = SharedFusedMoE.make_expert_params_mapping(
|
||||
self,
|
||||
ckpt_gate_proj_name="gate_proj",
|
||||
|
||||
@@ -625,6 +625,11 @@ class Indexer(nn.Module):
|
||||
super().__init__()
|
||||
self.vllm_config = vllm_config
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
self.is_fp4_ckpt = (
|
||||
self.quant_config is not None
|
||||
and self.quant_config.get_name() == "modelopt_fp4"
|
||||
)
|
||||
# self.indexer_cfg = config.attn_module_list_cfg[0]["attn_index"]
|
||||
self.topk_tokens = config.index_topk
|
||||
self.n_head = config.index_n_heads # 64
|
||||
@@ -639,18 +644,36 @@ class Indexer(nn.Module):
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.wq_b",
|
||||
)
|
||||
# Fused wk + weights_proj: single GEMM producing [head_dim + n_head].
|
||||
# weights_proj does not get quantized, so we run both with quant_config=None
|
||||
# wk may be upcasted from the default quant; experiments show fusion is always
|
||||
# faster unless WK proj is in FP4, which is not the case for all known quants.
|
||||
self.wk_weights_proj = MergedColumnParallelLinear(
|
||||
hidden_size,
|
||||
[self.head_dim, self.n_head],
|
||||
bias=False,
|
||||
quant_config=None,
|
||||
disable_tp=True,
|
||||
prefix=f"{prefix}.wk_weights_proj",
|
||||
)
|
||||
if self.is_fp4_ckpt:
|
||||
# Fused wk + weights_proj: single GEMM producing [head_dim + n_head].
|
||||
# weights_proj does not get quantized,
|
||||
# so we run both with quant_config=None
|
||||
# wk may be upcasted from the default quant;
|
||||
# experiments show fusion is always faster unless WK proj is in FP4,
|
||||
# which is not the case for all known quants.
|
||||
self.wk_weights_proj = MergedColumnParallelLinear(
|
||||
hidden_size,
|
||||
[self.head_dim, self.n_head],
|
||||
bias=False,
|
||||
quant_config=None,
|
||||
disable_tp=True,
|
||||
prefix=f"{prefix}.wk_weights_proj",
|
||||
)
|
||||
else:
|
||||
self.wk = ReplicatedLinear(
|
||||
hidden_size,
|
||||
self.head_dim,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.wk",
|
||||
)
|
||||
self.weights_proj = ReplicatedLinear(
|
||||
hidden_size,
|
||||
self.n_head,
|
||||
bias=False,
|
||||
quant_config=None,
|
||||
prefix=f"{prefix}.weights_proj",
|
||||
)
|
||||
self.k_norm = LayerNorm(self.head_dim, eps=1e-6)
|
||||
self.softmax_scale = self.head_dim**-0.5
|
||||
|
||||
@@ -691,11 +714,14 @@ class Indexer(nn.Module):
|
||||
q_pe, q_nope = torch.split(
|
||||
q, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1
|
||||
)
|
||||
|
||||
# Fused wk + weights_proj: one GEMM, then split
|
||||
kw, _ = self.wk_weights_proj(hidden_states)
|
||||
k = kw[:, : self.head_dim]
|
||||
weights_raw = kw[:, self.head_dim :]
|
||||
if self.is_fp4_ckpt:
|
||||
# Fused wk + weights_proj: one GEMM, then split
|
||||
kw, _ = self.wk_weights_proj(hidden_states)
|
||||
k = kw[:, : self.head_dim]
|
||||
weights = kw[:, self.head_dim :]
|
||||
else:
|
||||
k, _ = self.wk(hidden_states)
|
||||
weights, _ = self.weights_proj(hidden_states)
|
||||
|
||||
k = self.k_norm(k)
|
||||
k_pe, k_nope = torch.split(
|
||||
@@ -726,7 +752,7 @@ class Indexer(nn.Module):
|
||||
q_scale = q_scale.view(-1, self.n_head, 1)
|
||||
|
||||
weights = (
|
||||
weights_raw.unsqueeze(-1) * q_scale * self.softmax_scale * self.n_head**-0.5
|
||||
weights.unsqueeze(-1) * q_scale * self.softmax_scale * self.n_head**-0.5
|
||||
)
|
||||
weights = weights.squeeze(-1)
|
||||
|
||||
@@ -1314,6 +1340,10 @@ class DeepseekV2ForCausalLM(
|
||||
quant_config = vllm_config.quant_config
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
self.is_fp4_ckpt = (
|
||||
self.quant_config is not None
|
||||
and self.quant_config.get_name() == "modelopt_fp4"
|
||||
)
|
||||
|
||||
qk_nope_head_dim = getattr(config, "qk_nope_head_dim", 0)
|
||||
qk_rope_head_dim = getattr(config, "qk_rope_head_dim", 0)
|
||||
@@ -1439,12 +1469,13 @@ class DeepseekV2ForCausalLM(
|
||||
("qkv_proj", "k_proj", "k"),
|
||||
("qkv_proj", "v_proj", "v"),
|
||||
]
|
||||
# Fused indexer wk + weights_proj (shard 0 = wk, shard 1 = weights_proj)
|
||||
indexer_fused_mapping = [
|
||||
("wk_weights_proj", "wk", 0),
|
||||
("wk_weights_proj", "weights_proj", 1),
|
||||
]
|
||||
stacked_params_mapping.extend(indexer_fused_mapping)
|
||||
if self.is_fp4_ckpt:
|
||||
# Fused indexer wk + weights_proj (shard 0 = wk, shard 1 = weights_proj)
|
||||
indexer_fused_mapping = [
|
||||
("wk_weights_proj", "wk", 0),
|
||||
("wk_weights_proj", "weights_proj", 1),
|
||||
]
|
||||
stacked_params_mapping.extend(indexer_fused_mapping)
|
||||
|
||||
if self.use_mha:
|
||||
stacked_params_mapping.extend(mha_params_mapping)
|
||||
|
||||
Reference in New Issue
Block a user