[perf] Add fused MLA QKV + strided layernorm (#21116)
Signed-off-by: Mickael Seznec <mickael@mistral.ai>
Co-authored-by: mgoin <mgoin64@gmail.com>
@@ -42,6 +42,7 @@ from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                MergedColumnParallelLinear,
+                                               MergedReplicatedLinear,
                                                ReplicatedLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
@@ -336,7 +337,7 @@ class DeepseekV2Attention(nn.Module):
         kv_a, _ = latent_cache.split(
             [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
         latent_cache = latent_cache.unsqueeze(1)
-        kv_a = self.kv_a_layernorm(kv_a.contiguous())
+        kv_a = self.kv_a_layernorm(kv_a)
         kv = self.kv_b_proj(kv_a)[0]
         kv = kv.view(-1, self.num_local_heads,
                      self.qk_nope_head_dim + self.v_head_dim)
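Why the `.contiguous()` call can be dropped: `split()` along the last dimension returns views whose row stride exceeds their row length, and the previous RMSNorm kernel required packed input, so every layer paid for a copy. A minimal sketch of the strided case, using plain PyTorch rather than the fused CUDA kernel (shapes are DeepSeek-V2-like but assumed):

```python
import torch

# Splitting along the last dim yields non-contiguous views:
# each row of kv_a is a slice of a wider latent_cache row.
latent_cache = torch.randn(8, 576)           # kv_lora_rank=512 + rope=64
kv_a, k_pe = latent_cache.split([512, 64], dim=-1)
assert not kv_a.is_contiguous()              # shape (8, 512), stride (576, 1)

# A strided RMSNorm only needs the row stride, not a packed buffer.
# Hypothetical reference implementation, not the vLLM kernel:
def rms_norm_strided(x: torch.Tensor, weight: torch.Tensor,
                     eps: float = 1e-6) -> torch.Tensor:
    variance = x.float().pow(2).mean(dim=-1, keepdim=True)
    return (x.float() * torch.rsqrt(variance + eps)).to(x.dtype) * weight

out = rms_norm_strided(kv_a, torch.ones(512))  # no copy needed
```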
@@ -407,14 +408,24 @@ class DeepseekV2MLAAttention(nn.Module):
         self.max_position_embeddings = max_position_embeddings

         if self.q_lora_rank is not None:
-            self.q_a_proj = ReplicatedLinear(self.hidden_size,
-                                             self.q_lora_rank,
-                                             bias=False,
-                                             quant_config=quant_config,
-                                             prefix=f"{prefix}.q_a_proj")
+            self.fused_qkv_a_proj = MergedReplicatedLinear(
+                self.hidden_size,
+                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
+                bias=False,
+                quant_config=quant_config,
+                prefix=f"{prefix}.fused_qkv_a_proj")
+        else:
+            self.kv_a_proj_with_mqa = ReplicatedLinear(
+                self.hidden_size,
+                self.kv_lora_rank + self.qk_rope_head_dim,
+                bias=False,
+                quant_config=quant_config,
+                prefix=f"{prefix}.kv_a_proj_with_mqa")
+
+        if self.q_lora_rank is not None:
             self.q_a_layernorm = RMSNorm(self.q_lora_rank,
                                          eps=config.rms_norm_eps)
-            self.q_b_proj = ColumnParallelLinear(q_lora_rank,
+            self.q_b_proj = ColumnParallelLinear(self.q_lora_rank,
                                                  self.num_heads *
                                                  self.qk_head_dim,
                                                  bias=False,
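The point of `MergedReplicatedLinear` here is that the fused weight is just the row-wise stack of the two A-projection weights, so a single GEMM produces both latents. A toy equivalence check with invented shapes (this is the underlying algebra, not the vLLM layer):

```python
import torch

# Stacking the A-projection weights row-wise is exactly equivalent
# to running the two GEMMs separately, then splitting the output.
hidden, q_rank, kv_rank_rope = 64, 24, 16
w_q = torch.randn(q_rank, hidden)          # q_a_proj weight
w_kv = torch.randn(kv_rank_rope, hidden)   # kv_a_proj_with_mqa weight
w_fused = torch.cat([w_q, w_kv], dim=0)    # fused_qkv_a_proj weight

x = torch.randn(4, hidden)
q_c, kv_lora = (x @ w_fused.T).split([q_rank, kv_rank_rope], dim=-1)
torch.testing.assert_close(q_c, x @ w_q.T)
torch.testing.assert_close(kv_lora, x @ w_kv.T)
```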
@@ -427,13 +438,6 @@ class DeepseekV2MLAAttention(nn.Module):
                                                  bias=False,
                                                  quant_config=quant_config,
                                                  prefix=f"{prefix}.q_proj")
-
-        self.kv_a_proj_with_mqa = ReplicatedLinear(
-            self.hidden_size,
-            self.kv_lora_rank + self.qk_rope_head_dim,
-            bias=False,
-            quant_config=quant_config,
-            prefix=f"{prefix}.kv_a_proj_with_mqa")
         self.kv_a_layernorm = RMSNorm(self.kv_lora_rank,
                                       eps=config.rms_norm_eps)
         self.kv_b_proj = ColumnParallelLinear(
@@ -495,15 +499,24 @@ class DeepseekV2MLAAttention(nn.Module):
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
     ) -> torch.Tensor:
+        q_c = None
+        kv_lora = None
+
         if self.q_lora_rank is not None:
-            q_c = self.q_a_proj(hidden_states)[0]
+            qkv_lora = self.fused_qkv_a_proj(hidden_states)[0]
+            q_c, kv_lora = qkv_lora.split(
+                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
+                dim=-1,
+            )
             q_c = self.q_a_layernorm(q_c)
             q = self.q_b_proj(q_c)[0]
         else:
+            kv_lora = self.kv_a_proj_with_mqa(hidden_states)[0]
             q = self.q_proj(hidden_states)[0]
-        kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split(
-            [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
-        kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
+
+        kv_c, k_pe = kv_lora.split([self.kv_lora_rank, self.qk_rope_head_dim],
+                                   dim=-1)
+        kv_c_normed = self.kv_a_layernorm(kv_c)

         q = q.view(-1, self.num_local_heads, self.qk_head_dim)
         # Add head dim of 1 to k_pe
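Note how the two changes interlock: in the fused path, `kv_c` is now a strided view into `qkv_lora` rather than the packed output of its own GEMM, which is exactly why the strided-layernorm support above is needed. A rough view-only sketch (toy sizes; the real ranks come from the model config):

```python
import torch

# Toy shapes standing in for [q_lora_rank, kv_lora_rank, rope dim].
tokens, q_rank, kv_rank, rope = 4, 24, 16, 8
qkv_lora = torch.randn(tokens, q_rank + kv_rank + rope)

# First split q_c from (kv_c ++ k_pe), then split the kv part again;
# every result is a view into qkv_lora, so no copies are made.
q_c, kv_lora = qkv_lora.split([q_rank, kv_rank + rope], dim=-1)
kv_c, k_pe = kv_lora.split([kv_rank, rope], dim=-1)
assert not kv_c.is_contiguous()  # hence the strided RMSNorm
```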
@@ -837,6 +850,8 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
             # (param_name, shard_name, shard_id)
             ("gate_up_proj", "gate_proj", 0),
             ("gate_up_proj", "up_proj", 1),
+            ("fused_qkv_a_proj", "q_a_proj", 0),
+            ("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1),
         ]

         # Params for weights, fp8 weight scales, fp8 activation scales
@@ -871,6 +886,12 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
                 if (("mlp.experts." in name) and name not in params_dict):
                     continue
                 name = name.replace(weight_name, param_name)
+
+                # QKV fusion is optional, fall back to normal
+                # weight loading if it's not enabled
+                if ((param_name == "fused_qkv_a_proj")
+                        and name not in params_dict):
+                    continue
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue