[Perf] DSV3.2 Indexer Fused Weights Projection (#38684)

Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
This commit is contained in:
Benjamin Chislett
2026-04-01 23:34:49 -04:00
committed by GitHub
parent 694449050f
commit 5f96f9aff1
2 changed files with 25 additions and 14 deletions

View File

@@ -241,6 +241,9 @@ class DeepSeekMTP(nn.Module, DeepseekV2MixtureOfExperts):
("gate_up_proj", "up_proj", 1),
("fused_qkv_a_proj", "q_a_proj", 0),
("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1),
# Fused indexer wk + weights_proj (shard 0 = wk, shard 1 = weights_proj)
("wk_weights_proj", "wk", 0),
("wk_weights_proj", "weights_proj", 1),
]
expert_params_mapping = SharedFusedMoE.make_expert_params_mapping(

View File

@@ -639,21 +639,19 @@ class Indexer(nn.Module):
quant_config=quant_config,
prefix=f"{prefix}.wq_b",
)
self.wk = ReplicatedLinear(
# Fused wk + weights_proj: single GEMM producing [head_dim + n_head].
# weights_proj is not quantized, so both projections run with quant_config=None.
# As a result wk may be upcast from the default quant; experiments show fusion
# is always faster unless wk is in FP4, which no known quant uses.
self.wk_weights_proj = MergedColumnParallelLinear(
hidden_size,
self.head_dim,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.wk",
)
self.k_norm = LayerNorm(self.head_dim, eps=1e-6)
self.weights_proj = ReplicatedLinear(
hidden_size,
self.n_head,
[self.head_dim, self.n_head],
bias=False,
quant_config=None,
prefix=f"{prefix}.weights_proj",
disable_tp=True,
prefix=f"{prefix}.wk_weights_proj",
)
self.k_norm = LayerNorm(self.head_dim, eps=1e-6)
self.softmax_scale = self.head_dim**-0.5
self.scale_fmt = "ue8m0"
@@ -694,7 +692,11 @@ class Indexer(nn.Module):
q, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1
)
k, _ = self.wk(hidden_states)
# Fused wk + weights_proj: one GEMM, then split
kw, _ = self.wk_weights_proj(hidden_states)
k = kw[:, : self.head_dim]
weights_raw = kw[:, self.head_dim :]
k = self.k_norm(k)
k_pe, k_nope = torch.split(
k, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1
@@ -723,9 +725,8 @@ class Indexer(nn.Module):
q_fp8 = q_fp8.view(-1, self.n_head, self.head_dim)
q_scale = q_scale.view(-1, self.n_head, 1)
weights, _ = self.weights_proj(hidden_states)
weights = (
weights.unsqueeze(-1) * q_scale * self.softmax_scale * self.n_head**-0.5
weights_raw.unsqueeze(-1) * q_scale * self.softmax_scale * self.n_head**-0.5
)
weights = weights.squeeze(-1)
@@ -1438,6 +1439,13 @@ class DeepseekV2ForCausalLM(
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
]
# Fused indexer wk + weights_proj (shard 0 = wk, shard 1 = weights_proj)
indexer_fused_mapping = [
("wk_weights_proj", "wk", 0),
("wk_weights_proj", "weights_proj", 1),
]
stacked_params_mapping.extend(indexer_fused_mapping)
if self.use_mha:
stacked_params_mapping.extend(mha_params_mapping)
else: