Optimize QKNorm for MiniMax-M2/M2.1 (#31493)
Signed-off-by: xuebi <xuebi@minimaxi.com> Co-authored-by: xuebi <xuebi@minimaxi.com>
This commit is contained in:
@@ -234,8 +234,9 @@ class MiniMaxM2Attention(nn.Module):
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
q = self.q_norm(q)
|
||||
k = self.k_norm(k)
|
||||
q, k = MiniMaxText01RMSNormTP.forward_qk(
|
||||
self.q_norm, self.k_norm, q.contiguous(), k.contiguous()
|
||||
)
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
attn_output = self.attn(q, k, v)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
|
||||
Reference in New Issue
Block a user