[Attention] MLA move rotary embedding to cuda-graph region (#17668)

Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
Lucas Wilkinson authored 2025-05-08 23:14:42 -04:00, committed by GitHub
parent 760e3ecc8f
commit 5e6f939484
6 changed files with 35 additions and 121 deletions
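
For context, vLLM captures the model's forward pass with CUDA graphs (piecewise around attention), so moving the rotary-embedding call out of the attention backend and into `DeepseekV2MLAAttention.forward` puts it inside the captured region, where it is replayed instead of launched eagerly every decode step. Below is a minimal standalone sketch of that capture/replay pattern — plain PyTorch, not vLLM's implementation; `rope_like_work` is a hypothetical stand-in for the rotary op, and the sketch assumes a CUDA device:

```python
import torch

def rope_like_work(q: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for applying a rotary embedding to q.
    return q * torch.cos(q) + torch.roll(q, shifts=1, dims=-1) * torch.sin(q)

if torch.cuda.is_available():
    static_q = torch.randn(8, 16, device="cuda")

    # Warm up on a side stream before capture (required for CUDA graphs).
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        rope_like_work(static_q)
    torch.cuda.current_stream().wait_stream(s)

    # Ops issued inside the capture are replayed with no per-op CPU launch
    # cost; ops left outside the capture (e.g. inside an eager attention
    # backend) pay that launch cost on every step.
    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        static_out = rope_like_work(static_q)

    static_q.copy_(torch.randn(8, 16, device="cuda"))
    g.replay()  # static_out now reflects the new contents of static_q
```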

vllm/model_executor/models/deepseek_v2.py

@@ -453,7 +453,6 @@ class DeepseekV2MLAAttention(nn.Module):
             qk_rope_head_dim=self.qk_rope_head_dim,
             qk_head_dim=self.qk_head_dim,
             v_head_dim=self.v_head_dim,
-            rotary_emb=self.rotary_emb,
             kv_b_proj=self.kv_b_proj,
         )
@@ -475,6 +474,13 @@ class DeepseekV2MLAAttention(nn.Module):
             [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
         kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
         q = q.view(-1, self.num_local_heads, self.qk_head_dim)
+
+        # Add head dim of 1 to k_pe
+        k_pe = k_pe.unsqueeze(1)
+
+        q[..., self.qk_nope_head_dim:], k_pe = self.rotary_emb(
+            positions, q[..., self.qk_nope_head_dim:], k_pe)
+
         attn_out = self.mla_attn(
             q,
             kv_c_normed,
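
The added lines apply rotary position embedding only to the rope slice of `q` (its last `qk_rope_head_dim` channels) and to the single-head `k_pe`; the nope slice and the compressed `kv_c` latent stay position-independent, which is what lets MLA cache `kv_c` and `k_pe` compactly. A self-contained sketch of that slicing pattern — illustrative sizes and a rotate-half RoPE chosen for simplicity, not vLLM's `rotary_emb` module:

```python
import torch

# Illustrative sizes (assumptions, not DeepSeek's real dims).
NUM_TOKENS, NUM_HEADS = 4, 2
QK_NOPE_HEAD_DIM, QK_ROPE_HEAD_DIM = 8, 4
QK_HEAD_DIM = QK_NOPE_HEAD_DIM + QK_ROPE_HEAD_DIM

def rope_cos_sin(positions: torch.Tensor, dim: int, base: float = 10000.0):
    # Standard RoPE frequency tables for the given positions.
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    freqs = positions.float()[:, None] * inv_freq[None, :]  # [T, dim/2]
    emb = torch.cat((freqs, freqs), dim=-1)                 # [T, dim]
    return emb.cos(), emb.sin()

def apply_rope(x, cos, sin):
    # Rotate-half formulation; cos/sin broadcast across the head axis.
    x1, x2 = x.chunk(2, dim=-1)
    return x * cos + torch.cat((-x2, x1), dim=-1) * sin

positions = torch.arange(NUM_TOKENS)
q = torch.randn(NUM_TOKENS, NUM_HEADS, QK_HEAD_DIM)
k_pe = torch.randn(NUM_TOKENS, QK_ROPE_HEAD_DIM).unsqueeze(1)  # head dim of 1

cos, sin = rope_cos_sin(positions, QK_ROPE_HEAD_DIM)
cos, sin = cos[:, None, :], sin[:, None, :]  # [T, 1, rope_dim] for broadcasting

# Only the rope slice of q is rotated; the nope slice is left untouched,
# mirroring `q[..., self.qk_nope_head_dim:], k_pe = self.rotary_emb(...)`.
q[..., QK_NOPE_HEAD_DIM:] = apply_rope(q[..., QK_NOPE_HEAD_DIM:], cos, sin)
k_pe = apply_rope(k_pe, cos, sin)
```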