diff --git a/vllm/model_executor/layers/mamba/linear_attn.py b/vllm/model_executor/layers/mamba/linear_attn.py
index 278713408..8020efbe3 100644
--- a/vllm/model_executor/layers/mamba/linear_attn.py
+++ b/vllm/model_executor/layers/mamba/linear_attn.py
@@ -79,6 +79,28 @@ class MiniMaxText01RMSNormTP(CustomOp):
         assert residual is None, "RMSNorm does not support residual connection."
         return self._forward(x)
 
+    @staticmethod
+    def forward_qk(
+        q_norm: "MiniMaxText01RMSNormTP",
+        k_norm: "MiniMaxText01RMSNormTP",
+        q: torch.Tensor,
+        k: torch.Tensor,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        orig_dtype = q.dtype
+        q = q.to(torch.float32)
+        k = k.to(torch.float32)
+        q_var = q.pow(2).mean(dim=-1, keepdim=True)
+        k_var = k.pow(2).mean(dim=-1, keepdim=True)
+        if q_norm.tp_world > 1:
+            qk_var = torch.cat([q_var, k_var], dim=-1)
+            qk_var = tensor_model_parallel_all_reduce(qk_var) / q_norm.tp_world
+            q_var, k_var = qk_var.chunk(2, dim=-1)
+        q = q * torch.rsqrt(q_var + q_norm.variance_epsilon) * q_norm.weight
+        k = k * torch.rsqrt(k_var + k_norm.variance_epsilon) * k_norm.weight
+        q = q.to(orig_dtype)
+        k = k.to(orig_dtype)
+        return q, k
+
 
 class MiniMaxText01LinearKernel:
     @staticmethod
diff --git a/vllm/model_executor/models/minimax_m2.py b/vllm/model_executor/models/minimax_m2.py
index ee19288ae..822bf9b5c 100644
--- a/vllm/model_executor/models/minimax_m2.py
+++ b/vllm/model_executor/models/minimax_m2.py
@@ -234,8 +234,9 @@ class MiniMaxM2Attention(nn.Module):
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
-        q = self.q_norm(q)
-        k = self.k_norm(k)
+        q, k = MiniMaxText01RMSNormTP.forward_qk(
+            self.q_norm, self.k_norm, q.contiguous(), k.contiguous()
+        )
         q, k = self.rotary_emb(positions, q, k)
         attn_output = self.attn(q, k, v)
         output, _ = self.o_proj(attn_output)
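
The idea behind forward_qk is to normalize q and k together: their per-token variances are concatenated so that, under tensor parallelism, a single all-reduce covers both tensors, whereas calling q_norm and k_norm separately would presumably issue one collective each. Below is a minimal, single-process sketch of that fusion, not the vLLM implementation itself; fused_qk_rms_norm and fake_all_reduce are illustrative names, with fake_all_reduce standing in for vLLM's tensor_model_parallel_all_reduce.

    import torch


    def fake_all_reduce(x: torch.Tensor) -> torch.Tensor:
        # Single-"rank" stand-in: the real collective sums across TP ranks.
        return x


    def fused_qk_rms_norm(
        q: torch.Tensor,
        k: torch.Tensor,
        q_weight: torch.Tensor,
        k_weight: torch.Tensor,
        eps: float = 1e-6,
        tp_world: int = 1,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        orig_dtype = q.dtype
        q, k = q.float(), k.float()
        # Per-token mean of squares over the (possibly sharded) hidden dimension.
        q_var = q.pow(2).mean(dim=-1, keepdim=True)
        k_var = k.pow(2).mean(dim=-1, keepdim=True)
        if tp_world > 1:
            # Concatenate both variance vectors so ONE all-reduce covers q and k;
            # dividing the summed per-rank means by tp_world gives the global mean
            # (valid when every rank holds an equally sized shard).
            qk_var = torch.cat([q_var, k_var], dim=-1)
            qk_var = fake_all_reduce(qk_var) / tp_world
            q_var, k_var = qk_var.chunk(2, dim=-1)
        q = q * torch.rsqrt(q_var + eps) * q_weight
        k = k * torch.rsqrt(k_var + eps) * k_weight
        return q.to(orig_dtype), k.to(orig_dtype)


    q = torch.randn(4, 128, dtype=torch.bfloat16)
    k = torch.randn(4, 128, dtype=torch.bfloat16)
    q_out, k_out = fused_qk_rms_norm(q, k, torch.ones(128), torch.ones(128))
    print(q_out.shape, k_out.shape)  # torch.Size([4, 128]) torch.Size([4, 128])

The variance math is unchanged from a plain RMSNorm; the only structural difference is batching the two small variance tensors into one collective, which is why the patch concatenates, all-reduces once, then chunks the result back apart.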