From 156405d243924fbede2d4360494a81dea7203334 Mon Sep 17 00:00:00 2001 From: Xiaoshuang Wang <1790571317@qq.com> Date: Sun, 5 Apr 2026 01:55:52 +0800 Subject: [PATCH] [vLLM IR] gemma_rms_norm (#38780) Signed-off-by: Icey <1790571317@qq.com> --- vllm/ir/ops/layernorm.py | 5 +- vllm/kernels/aiter_ops.py | 10 ++-- vllm/kernels/vllm_c.py | 7 ++- vllm/kernels/xpu_ops.py | 4 +- vllm/model_executor/layers/layernorm.py | 66 +++++-------------------- 5 files changed, 25 insertions(+), 67 deletions(-) diff --git a/vllm/ir/ops/layernorm.py b/vllm/ir/ops/layernorm.py index 8471aa043..ac0c38a9e 100644 --- a/vllm/ir/ops/layernorm.py +++ b/vllm/ir/ops/layernorm.py @@ -16,7 +16,6 @@ def rms_norm( x_var = x if variance_size is None else x[..., :variance_size] variance = x_var.pow(2).mean(dim=-1, keepdim=True) x = x * torch.rsqrt(variance + epsilon) - x = x.to(orig_dtype) if weight is not None: - x = x * weight - return x + x = x.to(weight.dtype) * weight + return x.to(orig_dtype) diff --git a/vllm/kernels/aiter_ops.py b/vllm/kernels/aiter_ops.py index 1980051dd..14c2b87fb 100644 --- a/vllm/kernels/aiter_ops.py +++ b/vllm/kernels/aiter_ops.py @@ -36,13 +36,11 @@ AITER_SUPPORTED = is_aiter_found() rms_no_var_16bit_only = ( lambda x, weight, epsilon, variance_size=None: variance_size is None - and x.dtype - in ( - torch.float16, - torch.bfloat16, - ) + and x.dtype in (torch.float16, torch.bfloat16) + and (weight is None or weight.dtype == x.dtype) ) -"""AITER rms_norm only supports float16 and bfloat16 acts and no var_size override.""" +"""AITER rms_norm only supports float16 and bfloat16 acts, no var_size override, +and requires weight dtype to match x dtype.""" @ir.ops.rms_norm.register_impl( diff --git a/vllm/kernels/vllm_c.py b/vllm/kernels/vllm_c.py index fabb36d7b..fab91de2e 100644 --- a/vllm/kernels/vllm_c.py +++ b/vllm/kernels/vllm_c.py @@ -11,8 +11,11 @@ current_platform.import_kernels() CUDA_ALIKE = current_platform.is_cuda_alike() """Most kernels in this file are supported on all CUDA-alike platforms.""" -rms_no_var_size = lambda x, weight, epsilon, variance_size=None: variance_size is None -"""vLLM kernel does not support variance_size parameter.""" +rms_no_var_size = ( + lambda x, weight, epsilon, variance_size=None: variance_size is None + and (weight is None or weight.dtype == x.dtype) +) +"""vLLM kernel does not support variance_size parameter or mismatched weight dtype.""" @ir.ops.rms_norm.register_impl( diff --git a/vllm/kernels/xpu_ops.py b/vllm/kernels/xpu_ops.py index 3548fb868..c680c542c 100644 --- a/vllm/kernels/xpu_ops.py +++ b/vllm/kernels/xpu_ops.py @@ -18,7 +18,9 @@ def is_xpu_kernels_found() -> bool: XPU_KERNELS_SUPPORTED = is_xpu_kernels_found() """Kernels in this file are supported if vLLM XPU kernels are installed.""" -rms_no_var = lambda x, weight, epsilon, variance_size=None: variance_size is None +rms_no_var = lambda x, weight, epsilon, variance_size=None: variance_size is None and ( + weight is None or weight.dtype == x.dtype +) @ir.ops.rms_norm.register_impl( diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py index 7b222f9c4..df22f3102 100644 --- a/vllm/model_executor/layers/layernorm.py +++ b/vllm/model_executor/layers/layernorm.py @@ -376,46 +376,6 @@ class GemmaRMSNorm(CustomOp): self.weight = nn.Parameter(torch.zeros(hidden_size)) self.variance_epsilon = eps - @staticmethod - def _forward_static_no_residual( - weight: torch.Tensor, - variance_epsilon: float, - x: torch.Tensor, - ) -> torch.Tensor: - """PyTorch-native implementation equivalent to forward() without residual.""" - orig_dtype = x.dtype - x = x.float() - variance = x.pow(2).mean(dim=-1, keepdim=True) - x = x * torch.rsqrt(variance + variance_epsilon) - x = x * (1.0 + weight.float()) - x = x.to(orig_dtype) - return x - - @staticmethod - def _forward_static_with_residual( - weight: torch.Tensor, - variance_epsilon: float, - x: torch.Tensor, - residual: torch.Tensor, - ) -> tuple[torch.Tensor, torch.Tensor]: - """PyTorch-native implementation equivalent to forward() with residual.""" - orig_dtype = x.dtype - x = ( - x.float() + residual.float() - if orig_dtype == torch.float16 - else x + residual - ) - residual = x - - x = x.float() - variance = x.pow(2).mean(dim=-1, keepdim=True) - x = x * torch.rsqrt(variance + variance_epsilon) - # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16) - # See https://github.com/huggingface/transformers/pull/29402 - x = x * (1.0 + weight.float()) - x = x.to(orig_dtype) - return x, residual - def forward_native( self, x: torch.Tensor, @@ -423,30 +383,26 @@ class GemmaRMSNorm(CustomOp): ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: """PyTorch-native implementation equivalent to forward().""" if residual is None: - return self._forward_static_no_residual( - self.weight.data, self.variance_epsilon, x + return ir.ops.rms_norm( + x, self.weight.data.float() + 1.0, self.variance_epsilon ) else: - return self._forward_static_with_residual( - self.weight.data, self.variance_epsilon, x, residual + orig_dtype = x.dtype + x = ( + x.float() + residual.float() + if orig_dtype == torch.float16 + else x + residual ) + residual = x + return ir.ops.rms_norm( + x, self.weight.data.float() + 1.0, self.variance_epsilon + ).to(orig_dtype), residual def forward_cuda( self, x: torch.Tensor, residual: torch.Tensor | None = None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - if torch.compiler.is_compiling(): - return self.forward_native(x, residual) - - if not getattr(self, "_is_compiled", False): - self._forward_static_no_residual = torch.compile( # type: ignore - self._forward_static_no_residual - ) - self._forward_static_with_residual = torch.compile( # type: ignore - self._forward_static_with_residual - ) - self._is_compiled = True return self.forward_native(x, residual)