[Attention] MLA move rotary embedding to cuda-graph region (#17668)
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
@@ -453,7 +453,6 @@ class DeepseekV2MLAAttention(nn.Module):
             qk_rope_head_dim=self.qk_rope_head_dim,
             qk_head_dim=self.qk_head_dim,
             v_head_dim=self.v_head_dim,
-            rotary_emb=self.rotary_emb,
             kv_b_proj=self.kv_b_proj,
         )
 
@@ -475,6 +474,13 @@ class DeepseekV2MLAAttention(nn.Module):
             [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
         kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
 
+        q = q.view(-1, self.num_local_heads, self.qk_head_dim)
+        # Add head dim of 1 to k_pe
+        k_pe = k_pe.unsqueeze(1)
+
+        q[..., self.qk_nope_head_dim:], k_pe = self.rotary_emb(
+            positions, q[..., self.qk_nope_head_dim:], k_pe)
+
         attn_out = self.mla_attn(
             q,
             kv_c_normed,
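For context, below is a minimal, self-contained sketch of the pattern this diff introduces: rotary embedding is applied to the rope slice of q and to k_pe in the model's forward pass, i.e. inside the region that vLLM captures in CUDA graphs, before the tensors are handed to the attention op. The apply_rope helper, the tensor shapes, and the separate per-tensor calls are illustrative stand-ins chosen for this sketch; they are not vLLM's RotaryEmbedding API, which (as the diff shows) rotates q and k_pe in a single call.

import torch


def apply_rope(positions: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    # Toy rotary embedding: rotate adjacent dim pairs by per-position angles.
    dim = x.shape[-1]
    inv_freq = 1.0 / (10000.0 ** (torch.arange(0, dim, 2).float() / dim))
    freqs = positions.float()[:, None] * inv_freq[None, :]       # (tokens, dim/2)
    cos = freqs.cos().repeat_interleave(2, dim=-1).unsqueeze(1)  # (tokens, 1, dim)
    sin = freqs.sin().repeat_interleave(2, dim=-1).unsqueeze(1)
    x1, x2 = x[..., ::2], x[..., 1::2]
    rotated = torch.stack((-x2, x1), dim=-1).flatten(-2)
    return x * cos + rotated * sin


num_tokens, num_local_heads = 4, 8
qk_nope_head_dim, qk_rope_head_dim = 128, 64  # hypothetical sizes for the sketch
qk_head_dim = qk_nope_head_dim + qk_rope_head_dim

positions = torch.arange(num_tokens)
q = torch.randn(num_tokens, num_local_heads, qk_head_dim)
k_pe = torch.randn(num_tokens, qk_rope_head_dim)

# Add head dim of 1 to k_pe (mirrors the diff above).
k_pe = k_pe.unsqueeze(1)

# Rotate only the rope portion of q, plus k_pe, before the attention call.
q[..., qk_nope_head_dim:] = apply_rope(positions, q[..., qk_nope_head_dim:])
k_pe = apply_rope(positions, k_pe)

Because the rotary step now lives in the model code, it is captured along with the rest of the layer when CUDA graphs are recorded, which is why the first hunk can drop the rotary_emb argument previously passed to the MLA attention backend.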