diff --git a/vllm/model_executor/models/deepseek_mtp.py b/vllm/model_executor/models/deepseek_mtp.py
index 87364b1f8..126efb6f8 100644
--- a/vllm/model_executor/models/deepseek_mtp.py
+++ b/vllm/model_executor/models/deepseek_mtp.py
@@ -184,11 +184,16 @@ class DeepSeekMTP(nn.Module, DeepseekV2MixtureOfExperts):
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         self.config = vllm_config.model_config.hf_config
+        self.quant_config = vllm_config.quant_config
         self.model = DeepSeekMultiTokenPredictor(
             vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
         )
         # Set MoE hyperparameters
         self.set_moe_parameters()
+        self.is_fp4_ckpt = (
+            self.quant_config is not None
+            and self.quant_config.get_name() == "modelopt_fp4"
+        )

     def set_moe_parameters(self):
         self.expert_weights = []
@@ -241,11 +246,16 @@ class DeepSeekMTP(nn.Module, DeepseekV2MixtureOfExperts):
             ("gate_up_proj", "up_proj", 1),
             ("fused_qkv_a_proj", "q_a_proj", 0),
             ("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1),
-            # Fused indexer wk + weights_proj
-            ("wk_weights_proj", "wk", 0),
-            ("wk_weights_proj", "weights_proj", 1),
         ]
+        if self.is_fp4_ckpt:
+            # Fused indexer wk + weights_proj (shard 0 = wk, shard 1 = weights_proj)
+            indexer_fused_mapping = [
+                ("wk_weights_proj", "wk", 0),
+                ("wk_weights_proj", "weights_proj", 1),
+            ]
+            stacked_params_mapping.extend(indexer_fused_mapping)
+

         expert_params_mapping = SharedFusedMoE.make_expert_params_mapping(
             self,
             ckpt_gate_proj_name="gate_proj",
diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py
index f50e38b60..cfeb36f4a 100644
--- a/vllm/model_executor/models/deepseek_v2.py
+++ b/vllm/model_executor/models/deepseek_v2.py
@@ -625,6 +625,11 @@ class Indexer(nn.Module):
         super().__init__()
         self.vllm_config = vllm_config
         self.config = config
+        self.quant_config = quant_config
+        self.is_fp4_ckpt = (
+            self.quant_config is not None
+            and self.quant_config.get_name() == "modelopt_fp4"
+        )
         # self.indexer_cfg = config.attn_module_list_cfg[0]["attn_index"]
         self.topk_tokens = config.index_topk
         self.n_head = config.index_n_heads  # 64
@@ -639,18 +644,36 @@ class Indexer(nn.Module):
             quant_config=quant_config,
             prefix=f"{prefix}.wq_b",
         )
-        # Fused wk + weights_proj: single GEMM producing [head_dim + n_head].
-        # weights_proj does not get quantized, so we run both with quant_config=None
-        # wk may be upcasted from the default quant; experiments show fusion is always
-        # faster unless WK proj is in FP4, which is not the case for all known quants.
-        self.wk_weights_proj = MergedColumnParallelLinear(
-            hidden_size,
-            [self.head_dim, self.n_head],
-            bias=False,
-            quant_config=None,
-            disable_tp=True,
-            prefix=f"{prefix}.wk_weights_proj",
-        )
+        if self.is_fp4_ckpt:
+            # Fused wk + weights_proj: single GEMM producing [head_dim + n_head].
+            # weights_proj does not get quantized,
+            # so we run both with quant_config=None.
+            # wk may be upcasted from the default quant;
+            # experiments show fusion is always faster unless WK proj is in FP4,
+            # which is not the case for all known quants.
+            self.wk_weights_proj = MergedColumnParallelLinear(
+                hidden_size,
+                [self.head_dim, self.n_head],
+                bias=False,
+                quant_config=None,
+                disable_tp=True,
+                prefix=f"{prefix}.wk_weights_proj",
+            )
+        else:
+            self.wk = ReplicatedLinear(
+                hidden_size,
+                self.head_dim,
+                bias=False,
+                quant_config=quant_config,
+                prefix=f"{prefix}.wk",
+            )
+            self.weights_proj = ReplicatedLinear(
+                hidden_size,
+                self.n_head,
+                bias=False,
+                quant_config=None,
+                prefix=f"{prefix}.weights_proj",
+            )

         self.k_norm = LayerNorm(self.head_dim, eps=1e-6)
         self.softmax_scale = self.head_dim**-0.5
@@ -691,11 +714,14 @@ class Indexer(nn.Module):
         q_pe, q_nope = torch.split(
             q, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1
         )
-
-        # Fused wk + weights_proj: one GEMM, then split
-        kw, _ = self.wk_weights_proj(hidden_states)
-        k = kw[:, : self.head_dim]
-        weights_raw = kw[:, self.head_dim :]
+        if self.is_fp4_ckpt:
+            # Fused wk + weights_proj: one GEMM, then split
+            kw, _ = self.wk_weights_proj(hidden_states)
+            k = kw[:, : self.head_dim]
+            weights = kw[:, self.head_dim :]
+        else:
+            k, _ = self.wk(hidden_states)
+            weights, _ = self.weights_proj(hidden_states)

         k = self.k_norm(k)
         k_pe, k_nope = torch.split(
@@ -726,7 +752,7 @@ class Indexer(nn.Module):
         q_scale = q_scale.view(-1, self.n_head, 1)

         weights = (
-            weights_raw.unsqueeze(-1) * q_scale * self.softmax_scale * self.n_head**-0.5
+            weights.unsqueeze(-1) * q_scale * self.softmax_scale * self.n_head**-0.5
         )
         weights = weights.squeeze(-1)

@@ -1314,6 +1340,10 @@ class DeepseekV2ForCausalLM(
         quant_config = vllm_config.quant_config
         self.config = config
         self.quant_config = quant_config
+        self.is_fp4_ckpt = (
+            self.quant_config is not None
+            and self.quant_config.get_name() == "modelopt_fp4"
+        )

         qk_nope_head_dim = getattr(config, "qk_nope_head_dim", 0)
         qk_rope_head_dim = getattr(config, "qk_rope_head_dim", 0)
@@ -1439,12 +1469,13 @@ class DeepseekV2ForCausalLM(
             ("qkv_proj", "k_proj", "k"),
             ("qkv_proj", "v_proj", "v"),
         ]
-        # Fused indexer wk + weights_proj (shard 0 = wk, shard 1 = weights_proj)
-        indexer_fused_mapping = [
-            ("wk_weights_proj", "wk", 0),
-            ("wk_weights_proj", "weights_proj", 1),
-        ]
-        stacked_params_mapping.extend(indexer_fused_mapping)
+        if self.is_fp4_ckpt:
+            # Fused indexer wk + weights_proj (shard 0 = wk, shard 1 = weights_proj)
+            indexer_fused_mapping = [
+                ("wk_weights_proj", "wk", 0),
+                ("wk_weights_proj", "weights_proj", 1),
+            ]
+            stacked_params_mapping.extend(indexer_fused_mapping)

         if self.use_mha:
             stacked_params_mapping.extend(mha_params_mapping)
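Reviewer note: the fused path works because `MergedColumnParallelLinear([head_dim, n_head])` row-concatenates the two projection weights, so one GEMM followed by a split along the output dimension is numerically equivalent to running `wk` and `weights_proj` separately. A minimal standalone sketch demonstrating that equivalence in plain torch (not vLLM code; the sizes are illustrative):

```python
import torch

hidden_size, head_dim, n_head = 256, 128, 64  # illustrative sizes
x = torch.randn(4, hidden_size)

wk = torch.nn.Linear(hidden_size, head_dim, bias=False)
weights_proj = torch.nn.Linear(hidden_size, n_head, bias=False)

# Row-concatenating the weight matrices is what the fused
# MergedColumnParallelLinear([head_dim, n_head]) layer holds internally.
fused_weight = torch.cat([wk.weight, weights_proj.weight], dim=0)

kw = x @ fused_weight.T  # one GEMM -> [4, head_dim + n_head]
k, weights = kw[:, :head_dim], kw[:, head_dim:]

assert torch.allclose(k, wk(x), atol=1e-6)
assert torch.allclose(weights, weights_proj(x), atol=1e-6)
```

Row-concatenation of weights corresponds to column-concatenation of outputs, which is why shard 0 of the fused parameter can hold `wk` and shard 1 can hold `weights_proj`.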
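The `stacked_params_mapping` entries only matter at weight-loading time: checkpoints still store separate `wk` and `weights_proj` tensors, and the mapping tells the loader to write each one into the corresponding shard of the fused `wk_weights_proj` parameter. A simplified, hypothetical sketch of how such a mapping is consumed (the `resolve` helper is illustrative; the real logic lives in each model's `load_weights`):

```python
stacked_params_mapping = [
    # (fused param name, checkpoint param name, shard id)
    ("wk_weights_proj", "wk", 0),
    ("wk_weights_proj", "weights_proj", 1),
]

def resolve(ckpt_name: str) -> tuple[str, int | None]:
    """Map a checkpoint tensor name to (model param name, shard id)."""
    for fused_name, ckpt_part, shard_id in stacked_params_mapping:
        if ckpt_part in ckpt_name:
            return ckpt_name.replace(ckpt_part, fused_name), shard_id
    return ckpt_name, None  # not part of a fused param; load as-is

print(resolve("model.layers.0.self_attn.indexer.wk.weight"))
# ('model.layers.0.self_attn.indexer.wk_weights_proj.weight', 0)
print(resolve("model.layers.0.self_attn.indexer.weights_proj.weight"))
# ('model.layers.0.self_attn.indexer.wk_weights_proj.weight', 1)
```

Gating the mapping on `is_fp4_ckpt` keeps non-FP4 checkpoints loading `wk` and `weights_proj` into the separate `ReplicatedLinear` modules created in the `else` branch.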