Optimize QKNorm for MiniMax-M2/M2.1 (#31493)
Signed-off-by: xuebi <xuebi@minimaxi.com>
Co-authored-by: xuebi <xuebi@minimaxi.com>
This commit is contained in:
@@ -79,6 +79,28 @@ class MiniMaxText01RMSNormTP(CustomOp):
|
||||
assert residual is None, "RMSNorm does not support residual connection."
|
||||
return self._forward(x)
|
||||
|
||||
@staticmethod
def forward_qk(
    q_norm: "MiniMaxText01RMSNormTP",
    k_norm: "MiniMaxText01RMSNormTP",
    q: torch.Tensor,
    k: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Apply RMS normalization to ``q`` and ``k`` with a single collective.

    Both variances are computed locally in float32 over the last dimension;
    when tensor parallelism is active (``tp_world > 1``) the two variance
    tensors are fused into one all-reduce — instead of one collective per
    tensor — and averaged across ranks (assumes equal-size shards per rank).
    Outputs are returned in the inputs' original dtype.

    Args:
        q_norm: norm module providing ``tp_world``, ``variance_epsilon``
            and ``weight`` for the query tensor.
        k_norm: norm module providing ``variance_epsilon`` and ``weight``
            for the key tensor.
        q: query tensor, normalized over its last dimension.
        k: key tensor, normalized over its last dimension.

    Returns:
        Tuple ``(q, k)`` of normalized tensors in the original dtype of ``q``.
    """
    input_dtype = q.dtype
    # Upcast once; all statistics and scaling happen in float32.
    q_f32 = q.float()
    k_f32 = k.float()

    # Local mean of squares along the hidden dimension.
    q_var = q_f32.pow(2).mean(dim=-1, keepdim=True)
    k_var = k_f32.pow(2).mean(dim=-1, keepdim=True)

    if q_norm.tp_world > 1:
        # Fuse both statistics into one collective, then split them back.
        fused_var = torch.cat([q_var, k_var], dim=-1)
        fused_var = tensor_model_parallel_all_reduce(fused_var) / q_norm.tp_world
        q_var, k_var = fused_var.chunk(2, dim=-1)

    q_out = q_f32 * torch.rsqrt(q_var + q_norm.variance_epsilon) * q_norm.weight
    k_out = k_f32 * torch.rsqrt(k_var + k_norm.variance_epsilon) * k_norm.weight
    return q_out.to(input_dtype), k_out.to(input_dtype)
|
||||
|
||||
|
||||
class MiniMaxText01LinearKernel:
|
||||
@staticmethod
|
||||
|
||||
@@ -234,8 +234,9 @@ class MiniMaxM2Attention(nn.Module):
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
q = self.q_norm(q)
|
||||
k = self.k_norm(k)
|
||||
q, k = MiniMaxText01RMSNormTP.forward_qk(
|
||||
self.q_norm, self.k_norm, q.contiguous(), k.contiguous()
|
||||
)
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
attn_output = self.attn(q, k, v)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
|
||||
Reference in New Issue
Block a user