[Attention] MLA support for V1 (#13789)

Signed-off-by: Yang Chen <yangche@fb.com>
Author: Yang Chen
Date: 2025-02-27 10:14:17 -08:00 (committed by GitHub)
Commit: 58d1b2aa77 (parent: f1579b229d)
10 changed files with 1340 additions and 59 deletions


@@ -420,9 +420,15 @@ class DeepseekV2MLAAttention(nn.Module):
             mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
             self.scaling = self.scaling * mscale * mscale
 
+        # In the MLA backend, kv_cache includes both k_c and
+        # pe (i.e. decoupled position embeddings). In particular,
+        # the concat_and_cache_mla op requires
+        #     k_c.size(1) + k_pe.size(1) == kv_cache.size(2)
+        # i.e.
+        #     kv_lora_rank + qk_rope_head_dim == head_size
         self.mla_attn = Attention(
             num_heads=self.num_local_heads,
-            head_size=self.kv_lora_rank,
+            head_size=self.kv_lora_rank + self.qk_rope_head_dim,
             scale=self.scaling,
             num_kv_heads=1,
             cache_config=cache_config,
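
The new comment encodes a cache-layout invariant, and the head_size change follows directly from it. A minimal PyTorch sketch of that invariant follows; this is illustrative code, not part of the commit, and the dimensions (kv_lora_rank = 512, qk_rope_head_dim = 64) are assumed from DeepSeek-V2's published config.

import torch

# Assumed DeepSeek-V2 dimensions, for illustration only.
kv_lora_rank = 512      # width of the compressed latent k_c
qk_rope_head_dim = 64   # width of the decoupled rotary part k_pe
num_tokens = 4

# Per-token pieces that the MLA cache stores together.
k_c = torch.randn(num_tokens, kv_lora_rank)
k_pe = torch.randn(num_tokens, qk_rope_head_dim)

# Both parts live side by side in a single cache "head", so the
# cache head_size must be their sum: 512 + 64 = 576.
head_size = kv_lora_rank + qk_rope_head_dim
entry = torch.cat([k_c, k_pe], dim=-1)

assert k_c.size(1) + k_pe.size(1) == head_size
assert entry.shape == (num_tokens, head_size)

This is why the Attention layer is now constructed with head_size=self.kv_lora_rank + self.qk_rope_head_dim rather than self.kv_lora_rank alone.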
@@ -458,7 +464,10 @@ class DeepseekV2MLAAttention(nn.Module):
         kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split(
             [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
         kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
-        return self.mla_attn(hidden_states_or_q_c, kv_c_normed, k_pe)
+        return self.mla_attn(hidden_states_or_q_c,
+                             kv_c_normed,
+                             k_pe,
+                             output_shape=hidden_states.shape)


 class DeepseekV2DecoderLayer(nn.Module):
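
For shape-level context on the hunk above, here is a standalone sketch of the projection-and-split step. It is not vLLM code: nn.Linear stands in for the real kv_a_proj_with_mqa (a vLLM parallel linear layer whose forward returns a tuple, hence the [0] in the diff), and all sizes are assumed for illustration.

import torch
from torch import nn

hidden_size = 5120        # assumed model width
kv_lora_rank = 512
qk_rope_head_dim = 64
num_tokens = 4

kv_a_proj_with_mqa = nn.Linear(
    hidden_size, kv_lora_rank + qk_rope_head_dim, bias=False)

hidden_states = torch.randn(num_tokens, hidden_size)

# One projection yields both the compressed latent and the rotary
# part; split() separates them along the last dimension.
kv_c, k_pe = kv_a_proj_with_mqa(hidden_states).split(
    [kv_lora_rank, qk_rope_head_dim], dim=-1)

assert kv_c.shape == (num_tokens, kv_lora_rank)
assert k_pe.shape == (num_tokens, qk_rope_head_dim)

# The new output_shape=hidden_states.shape argument in the diff lets the
# attention layer size its output buffer from the hidden states, since
# the query it receives (hidden_states_or_q_c) may have a different width.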