Optimize QKNorm for MiniMax-M2/M2.1 (#31493)

Signed-off-by: xuebi <xuebi@minimaxi.com>
Co-authored-by: xuebi <xuebi@minimaxi.com>
This commit is contained in:
Roger Young
2025-12-30 00:30:18 +08:00
committed by GitHub
parent b3a2bdf1ac
commit 5bc664110f
2 changed files with 25 additions and 2 deletions

View File

@@ -79,6 +79,28 @@ class MiniMaxText01RMSNormTP(CustomOp):
assert residual is None, "RMSNorm does not support residual connection."
return self._forward(x)
@staticmethod
def forward_qk(
q_norm: "MiniMaxText01RMSNormTP",
k_norm: "MiniMaxText01RMSNormTP",
q: torch.Tensor,
k: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
orig_dtype = q.dtype
q = q.to(torch.float32)
k = k.to(torch.float32)
q_var = q.pow(2).mean(dim=-1, keepdim=True)
k_var = k.pow(2).mean(dim=-1, keepdim=True)
if q_norm.tp_world > 1:
qk_var = torch.cat([q_var, k_var], dim=-1)
qk_var = tensor_model_parallel_all_reduce(qk_var) / q_norm.tp_world
q_var, k_var = qk_var.chunk(2, dim=-1)
q = q * torch.rsqrt(q_var + q_norm.variance_epsilon) * q_norm.weight
k = k * torch.rsqrt(k_var + k_norm.variance_epsilon) * k_norm.weight
q = q.to(orig_dtype)
k = k.to(orig_dtype)
return q, k
class MiniMaxText01LinearKernel:
@staticmethod

View File

@@ -234,8 +234,9 @@ class MiniMaxM2Attention(nn.Module):
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q = self.q_norm(q)
k = self.k_norm(k)
q, k = MiniMaxText01RMSNormTP.forward_qk(
self.q_norm, self.k_norm, q.contiguous(), k.contiguous()
)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output)