[Attention] MLA: move o_proj and q_proj into the cuda-graph region (#17484)
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
This commit is contained in:
@@ -454,9 +454,7 @@ class DeepseekV2MLAAttention(nn.Module):
|
||||
qk_head_dim=self.qk_head_dim,
|
||||
v_head_dim=self.v_head_dim,
|
||||
rotary_emb=self.rotary_emb,
|
||||
q_proj=self.q_proj if self.q_lora_rank is None else self.q_b_proj,
|
||||
kv_b_proj=self.kv_b_proj,
|
||||
o_proj=self.o_proj,
|
||||
)
|
||||
|
||||
self.prefix = prefix
|
||||
@@ -468,17 +466,22 @@ class DeepseekV2MLAAttention(nn.Module):
|
||||
hidden_states: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
if self.q_lora_rank is not None:
|
||||
ckq = self.q_a_proj(hidden_states)[0]
|
||||
hidden_states_or_q_c = self.q_a_layernorm(ckq)
|
||||
q_c = self.q_a_proj(hidden_states)[0]
|
||||
q_c = self.q_a_layernorm(q_c)
|
||||
q = self.q_b_proj(q_c)[0]
|
||||
else:
|
||||
hidden_states_or_q_c = hidden_states
|
||||
q = self.q_proj(hidden_states)[0]
|
||||
kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split(
|
||||
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
|
||||
kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
|
||||
return self.mla_attn(hidden_states_or_q_c,
|
||||
kv_c_normed,
|
||||
k_pe,
|
||||
output_shape=hidden_states.shape)
|
||||
|
||||
attn_out = self.mla_attn(
|
||||
q,
|
||||
kv_c_normed,
|
||||
k_pe,
|
||||
output_shape=(hidden_states.shape[0],
|
||||
self.num_local_heads * self.v_head_dim))
|
||||
return self.o_proj(attn_out)[0]
|
||||
|
||||
|
||||
class DeepseekV2DecoderLayer(nn.Module):
|
||||
|
||||
Reference in New Issue
Block a user