Optimize QKNorm for MiniMax-M2/M2.1 (#31493)

Signed-off-by: xuebi <xuebi@minimaxi.com>
Co-authored-by: xuebi <xuebi@minimaxi.com>
This commit is contained in:
Roger Young
2025-12-30 00:30:18 +08:00
committed by GitHub
parent b3a2bdf1ac
commit 5bc664110f
2 changed files with 25 additions and 2 deletions

View File

@@ -234,8 +234,9 @@ class MiniMaxM2Attention(nn.Module):
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q = self.q_norm(q)
k = self.k_norm(k)
q, k = MiniMaxText01RMSNormTP.forward_qk(
self.q_norm, self.k_norm, q.contiguous(), k.contiguous()
)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output)