From 5f96f9aff10fec0c31eb61f7587b3789400f9c17 Mon Sep 17 00:00:00 2001 From: Benjamin Chislett Date: Wed, 1 Apr 2026 23:34:49 -0400 Subject: [PATCH] [Perf] DSV3.2 Indexer Fused Weights Projection (#38684) Signed-off-by: Benjamin Chislett --- vllm/model_executor/models/deepseek_mtp.py | 3 ++ vllm/model_executor/models/deepseek_v2.py | 36 +++++++++++++--------- 2 files changed, 25 insertions(+), 14 deletions(-) diff --git a/vllm/model_executor/models/deepseek_mtp.py b/vllm/model_executor/models/deepseek_mtp.py index c75ee1a1b..87364b1f8 100644 --- a/vllm/model_executor/models/deepseek_mtp.py +++ b/vllm/model_executor/models/deepseek_mtp.py @@ -241,6 +241,9 @@ class DeepSeekMTP(nn.Module, DeepseekV2MixtureOfExperts): ("gate_up_proj", "up_proj", 1), ("fused_qkv_a_proj", "q_a_proj", 0), ("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1), + # Fused indexer wk + weights_proj + ("wk_weights_proj", "wk", 0), + ("wk_weights_proj", "weights_proj", 1), ] expert_params_mapping = SharedFusedMoE.make_expert_params_mapping( diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index f1c4a7b21..f50e38b60 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -639,21 +639,19 @@ class Indexer(nn.Module): quant_config=quant_config, prefix=f"{prefix}.wq_b", ) - self.wk = ReplicatedLinear( + # Fused wk + weights_proj: single GEMM producing [head_dim + n_head]. + # weights_proj does not get quantized, so we run both with quant_config=None. + # wk may be upcast from the default quant; experiments show fusion is always + # faster unless WK proj is in FP4, which is not the case for any known quant.
+ self.wk_weights_proj = MergedColumnParallelLinear( hidden_size, - self.head_dim, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.wk", - ) - self.k_norm = LayerNorm(self.head_dim, eps=1e-6) - self.weights_proj = ReplicatedLinear( - hidden_size, - self.n_head, + [self.head_dim, self.n_head], bias=False, quant_config=None, - prefix=f"{prefix}.weights_proj", + disable_tp=True, + prefix=f"{prefix}.wk_weights_proj", ) + self.k_norm = LayerNorm(self.head_dim, eps=1e-6) self.softmax_scale = self.head_dim**-0.5 self.scale_fmt = "ue8m0" @@ -694,7 +692,11 @@ class Indexer(nn.Module): q, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1 ) - k, _ = self.wk(hidden_states) + # Fused wk + weights_proj: one GEMM, then split + kw, _ = self.wk_weights_proj(hidden_states) + k = kw[:, : self.head_dim] + weights_raw = kw[:, self.head_dim :] + k = self.k_norm(k) k_pe, k_nope = torch.split( k, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1 @@ -723,9 +725,8 @@ class Indexer(nn.Module): q_fp8 = q_fp8.view(-1, self.n_head, self.head_dim) q_scale = q_scale.view(-1, self.n_head, 1) - weights, _ = self.weights_proj(hidden_states) weights = ( - weights.unsqueeze(-1) * q_scale * self.softmax_scale * self.n_head**-0.5 + weights_raw.unsqueeze(-1) * q_scale * self.softmax_scale * self.n_head**-0.5 ) weights = weights.squeeze(-1) @@ -1438,6 +1439,13 @@ class DeepseekV2ForCausalLM( ("qkv_proj", "k_proj", "k"), ("qkv_proj", "v_proj", "v"), ] + # Fused indexer wk + weights_proj (shard 0 = wk, shard 1 = weights_proj) + indexer_fused_mapping = [ + ("wk_weights_proj", "wk", 0), + ("wk_weights_proj", "weights_proj", 1), + ] + stacked_params_mapping.extend(indexer_fused_mapping) + if self.use_mha: stacked_params_mapping.extend(mha_params_mapping) else: